# Feature Extraction using YOLOv9

In [None]:
import os
import shutil
import subprocess
import re

def run_yolov9_detect(model, base_dir, video):
    """Launch ``yolov9-main/detect.py`` as a subprocess for one source folder.

    Args:
        model: Weight variant; selects ``yolov9-main/yolov9-<model>-converted.pt``.
        base_dir: Folder that contains the source entry.
        video: Name of the entry inside ``base_dir``; also used, together with
            ``model``, as the ``--name`` of the run.

    The script's stdout is printed on success. A non-zero exit status is not
    propagated: the captured stderr is printed instead.
    """
    source_path = os.path.join(base_dir, video)
    weights_path = f'yolov9-main/yolov9-{model}-converted.pt'
    run_name = f'{video}_{model}'

    argv = ['python3', 'yolov9-main/detect.py']
    argv += ['--source', f'{source_path}']
    argv += ['--img', '640']
    argv += ['--weights', weights_path]
    argv += ['--name', run_name]
    # Write label files with confidences and per-object crops, but no annotated images.
    argv += ['--save-txt', '--save-conf', '--save-crop', '--nosave']

    try:
        completed = subprocess.run(argv, check=True, capture_output=True, text=True)
    except subprocess.CalledProcessError as failure:
        print(f"Script failed with error:\n{failure.stderr}")
    else:
        print(f"Script output:\n{completed.stdout}")
        
base_dir = 'keyframes_'
models = ['c', 'e']
# Create a list of entries in the keyframes folder. Hidden entries (such as the
# '.DS_Store' file macOS creates) are filtered out instead of being removed with
# list.remove(), which raises ValueError whenever the entry is not present.
directories = [entry for entry in os.listdir(base_dir) if not entry.startswith('.')]
print(directories)

# Run detection for each video and model
for video in directories:
    for model in models:
        run_yolov9_detect(model, base_dir, video)


# Image Analysis and Description using Detectron2 and Vision-Transformer Models

### 1. Object Detection and Segmentation
- **Detectron2**: An object detection library by Facebook AI Research (FAIR).
  - **Model**: Utilizes a pre-trained Mask R-CNN model for detecting and segmenting objects within an image.
  - **Configuration**: The model can be configured to run on either CPU or GPU, depending on the availability of a CUDA-enabled GPU.

### 2. Image Captioning
- **Transformers Library**: A library by Hugging Face for natural language processing and vision tasks.
  - **Model**: Uses the `VisionEncoderDecoderModel`, which combines a vision transformer (ViT) encoder with a GPT-2 decoder.
  - **Processor**: The `ViTImageProcessor` for preprocessing images and `AutoTokenizer` for handling the text generation.


1. **Image Loading**:
   - Load an image using the `PIL` library.

2. **Object Detection and Segmentation**:
   - Configure and use the Detectron2 model to detect objects and generate segmentation masks.
   - Visualize the segmentation results and extract instances (detected objects) from the image.

3. **Image and Object Description**:
   - Generate a detailed description of the entire image, focusing on the background.
   - For each detected object, generate specific descriptions by focusing on their bounding boxes.

#### Results

- **Background Description**: Provides a textual description of the overall background of the image.
- **Object Descriptions**: Generates individual descriptions for each detected object, detailing their appearance, actions, or other relevant attributes.


In [None]:
import cv2
import torch
import numpy as np
from PIL import Image
from transformers import VisionEncoderDecoderModel, ViTImageProcessor, AutoTokenizer
from detectron2.engine import DefaultPredictor
from detectron2.config import get_cfg
from detectron2 import model_zoo
from detectron2.utils.visualizer import Visualizer
from detectron2.data import MetadataCatalog

# Open an image file with PIL (no further preprocessing is applied)
def load_image(image_path):
    """Return the PIL image opened from ``image_path``."""
    return Image.open(image_path)

# Function to detect objects in the image
def detect_objects(image_path, device):
    """Run Mask R-CNN (R50-FPN, COCO) on an image and render its predictions.

    Args:
        image_path: Path of the image file; it is read with ``cv2.imread``.
        device: Value assigned to ``cfg.MODEL.DEVICE`` (e.g. ``"cpu"`` or ``"cuda"``).

    Returns:
        A tuple ``(instances, segmented_image)``: the predictor's
        ``outputs["instances"]`` and the visualizer rendering with its channel
        order reversed.

    Raises:
        FileNotFoundError: If the image cannot be read.
    """
    config_name = "COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml"

    # Configure Detectron2
    cfg = get_cfg()
    cfg.merge_from_file(model_zoo.get_config_file(config_name))
    cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.5  # set threshold for this model
    cfg.MODEL.DEVICE = device  # Use CPU or GPU
    cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url(config_name)
    predictor = DefaultPredictor(cfg)

    # Read the image; cv2.imread signals failure by returning None, not by raising
    image = cv2.imread(image_path)
    if image is None:
        raise FileNotFoundError(f"Could not read image: {image_path}")
    outputs = predictor(image)
    instances = outputs["instances"]

    # Visualize the predictions. The visualizer converts the predictions to
    # NumPy, so they must live on the CPU even when inference ran on a GPU.
    v = Visualizer(image[:, :, ::-1], MetadataCatalog.get(cfg.DATASETS.TRAIN[0]), scale=1.2)
    out = v.draw_instance_predictions(instances.to("cpu"))
    segmented_image = out.get_image()[:, :, ::-1]

    return instances, segmented_image

# Function to describe the background or specific object
_CAPTIONER_CHECKPOINT = "nlpconnect/vit-gpt2-image-captioning"
_CAPTIONER_CACHE = {}


def _get_captioner(device):
    """Return ``(model, processor, tokenizer)`` for ``device``, loading them only once."""
    cache_key = str(device)
    if cache_key not in _CAPTIONER_CACHE:
        model = VisionEncoderDecoderModel.from_pretrained(_CAPTIONER_CHECKPOINT).to(device)
        processor = ViTImageProcessor.from_pretrained(_CAPTIONER_CHECKPOINT)
        tokenizer = AutoTokenizer.from_pretrained(_CAPTIONER_CHECKPOINT)
        _CAPTIONER_CACHE[cache_key] = (model, processor, tokenizer)
    return _CAPTIONER_CACHE[cache_key]


def describe_image(image, focus_area=None, device='cpu'):
    """Generate a caption for a PIL image or for a rectangular region of it.

    Args:
        image: PIL image to describe.
        focus_area: Optional ``(left, upper, right, lower)`` box; when given,
            the image is cropped to it before captioning.
        device: Torch device on which the captioning model runs.

    Returns:
        The decoded caption string.
    """
    # Loading the checkpoints is by far the slowest step, so reuse them across
    # calls (e.g. when describing every detected object of one image).
    model, processor, tokenizer = _get_captioner(device)

    # Preprocess the image. ``is not None`` keeps array-like boxes from being
    # evaluated for truthiness; an empty box still means "no crop".
    if focus_area is not None and len(focus_area) > 0:
        image = image.crop(tuple(focus_area))
    # The ViT processor expects three channels; palette/RGBA/grayscale inputs are converted.
    if image.mode != "RGB":
        image = image.convert("RGB")

    pixel_values = processor(images=image, return_tensors="pt").pixel_values.to(device)

    # Generate the caption
    output_ids = model.generate(pixel_values, max_length=50, num_beams=4, eos_token_id=tokenizer.eos_token_id)
    caption = tokenizer.decode(output_ids[0], skip_special_tokens=True)

    return caption

# Example usage: caption one keyframe and run instance segmentation on it.
# Prefer the GPU when CUDA is available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
image_path = 'keyframes/00176/00176_Scene-5.jpg'

# Load the image with PIL (used for captioning) and run detection on the same
# path (detect_objects reads the file again itself with OpenCV)
image = load_image(image_path)
instances, segmented_image = detect_objects(image_path, device)

# Describe the entire image (background): no focus_area is passed, so the
# caption covers the full frame
background_description = describe_image(image, device=device)
print("Background description:", background_description)

# Describe specific objects (disabled): crops each predicted box and captions it
# for i in range(len(instances)):
#     bbox = instances.pred_boxes[i].tensor.cpu().numpy()[0]
#     focus_area = (bbox[0], bbox[1], bbox[2], bbox[3])
#     object_description = describe_image(image, focus_area, device=device)
#     print(f"Object {i+1} description:", object_description)

# Optionally, save the segmented image (disabled)
# NOTE(review): detect_objects reverses the channel order of the visualizer
# output, so segmented_image is presumably BGR; reverse it again before
# Image.fromarray if colors look swapped - confirm.
# segmented_image_pil = Image.fromarray(segmented_image)
# segmented_image_pil.save("segmented_image.jpg")


# Color Detection and/or Keyframe Description using BLIP

In [None]:
from transformers import BlipProcessor, BlipForConditionalGeneration
import torch
import cv2
import re

# Load the BLIP model and processor once at cell level; generate_caption below
# reads both through these module-level names. No device is selected here, so
# the model stays wherever from_pretrained places it.
processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")

def generate_caption(image_path, prompt="Describe the colors."):
    """Caption the image stored at ``image_path`` with the BLIP model.

    Args:
        image_path: Path of the image file; it is read with ``cv2.imread``.
        prompt: Accepted for backward compatibility but not used: only the
            image is handed to the processor.
            NOTE(review): decide whether the prompt should be passed as text
            conditioning or the parameter dropped.

    Returns:
        The decoded caption string.

    Raises:
        FileNotFoundError: If the image cannot be read.
    """
    # cv2.imread returns None for a missing or undecodable file; fail with a
    # clear message instead of an opaque cvtColor assertion error.
    raw_image = cv2.imread(image_path)
    if raw_image is None:
        raise FileNotFoundError(f"Could not read image: {image_path}")
    raw_image = cv2.cvtColor(raw_image, cv2.COLOR_BGR2RGB)

    # Prepare the image for BLIP
    inputs = processor(images=raw_image, return_tensors="pt")

    # Generate caption
    out = model.generate(**inputs)
    caption = processor.decode(out[0], skip_special_tokens=True)
    return caption

def extract_colors(caption):
    """Return the basic color names that occur as whole words in ``caption``.

    Matching is case-insensitive, and the spelling "grey" is reported as
    "gray". The result follows the order of the internal color list, not the
    order of appearance in the caption, and contains each color at most once.
    """
    colors = ["red", "green", "blue", "yellow", "white", "black", "orange", "purple", "brown", "gray", "pink"]
    # Alternative spellings, expressed as regex fragments keyed by canonical name.
    spellings = {"gray": r"gr[ae]y"}
    found_colors = [
        color
        for color in colors
        if re.search(r'\b' + spellings.get(color, re.escape(color)) + r'\b', caption, re.IGNORECASE)
    ]
    return found_colors


# Path of the cropped detection to caption
image_path = 'yolov9-main/runs/detect/yolov9_640_detect3/crops/car/07.jpg'
# NOTE(review): target_label is not referenced anywhere else in this cell - confirm whether it is still needed
target_label = 'backpack'  # Replace with the target label (e.g., 'boat', 'backpack')

# For description of keyframes use the corresponding path
# image_path = 'keyframes/00101/00101.mp4_start_6894.jpg'  # Replace with the path to your image

# Generate caption for the image
caption = generate_caption(image_path)
print("Generated Caption:", caption)

# Extract colors from the caption (keyword lookup on the caption text, not a
# pixel analysis). Comment this out if keyframe description is needed
colors_in_caption = extract_colors(caption)
print("Colors found in caption:", colors_in_caption)


# Dominant Color Detection with Enhanced Vibrance and Saturation

This script processes an image to identify the two most dominant colors, emphasizing vibrant colors over less vibrant ones like gray and black. The key steps involved in the process are:

1. **Image Preprocessing**:
   - **Saturation Adjustment**: Increases the saturation of all colors in the image to make colors more vivid.
   - **Vibrance Adjustment**: Boosts the vibrance of colors, particularly those that are less saturated, ensuring more nuanced color enhancement.

2. **Color Filtering**:
   - Converts the image to the HSV color space to filter out low-saturation colors, reducing the likelihood of selecting colors like gray or black as dominant colors.

3. **K-Means Clustering**:
   - Uses K-means clustering to identify the most common colors in the preprocessed image. The algorithm ensures that the top colors are distinct by checking the Euclidean distance between colors.

4. **Color Preference Weighting**:
   - Applies a weighting mechanism to prioritize vibrant colors (e.g., red, green, blue, yellow, orange) over less vibrant ones during the color matching process.

5. **Output**:
   - Ensures that the two most dominant colors are unique and outputs their names and RGB values. The original and preprocessed images are displayed side-by-side to visualize the effect of the adjustments.

In [None]:
import cv2
import numpy as np
import matplotlib.pyplot as plt
from sklearn.cluster import KMeans
import webcolors
from scipy.spatial.distance import cdist

# Define a dictionary of common colors with an additional preference weight for each color.
# Each entry maps a color name to (hex code, weight). closest_color divides the
# squared RGB distance by the weight, so a larger weight makes a color more
# likely to be chosen and a weight of 1 leaves the distance unscaled.
COMMON_COLORS = {
    'red': ('#FF0000', 10),
    'green': ('#008000', 10),
    'blue': ('#0000FF', 10),
    'yellow': ('#FFFF00', 10),
    'black': ('#000000', 1),
    'white': ('#FFFFFF', 10),
    # 'grey': ('#808080', 1),
    'orange': ('#FFA500', 8)
    # Disabled candidates; re-enabling them requires a comma after the 'orange' entry
    # 'pink': ('#FFC0CB', 5),
    # 'purple': ('#800080', 5),
    # 'brown': ('#A52A2A', 3)
}

def closest_color(requested_color):
    """Name the ``COMMON_COLORS`` entry nearest to an RGB triple.

    The score of each candidate is its squared RGB distance to
    ``requested_color`` divided by the candidate's weight; the name with the
    smallest score is returned. Candidates with an identical score overwrite
    each other, so the one listed last wins.
    """
    name_by_score = {}
    for color_name, (hex_code, weight) in COMMON_COLORS.items():
        ref_r, ref_g, ref_b = webcolors.hex_to_rgb(hex_code)
        squared_distance = (
            (ref_r - requested_color[0]) ** 2
            + (ref_g - requested_color[1]) ** 2
            + (ref_b - requested_color[2]) ** 2
        )
        name_by_score[squared_distance / weight] = color_name
    best_score = min(name_by_score)
    return name_by_score[best_score]

def adjust_saturation(image, saturation_scale=2.0):
    """
    Multiply the HSV saturation channel of an RGB image by ``saturation_scale``.

    The scaled channel is clipped to [0, 255] before the image is converted
    back to RGB; hue and value are left untouched.
    """
    hsv_float = cv2.cvtColor(image, cv2.COLOR_RGB2HSV).astype(np.float32)
    scaled_saturation = hsv_float[..., 1] * saturation_scale
    hsv_float[..., 1] = np.clip(scaled_saturation, 0, 255)
    hsv_uint8 = hsv_float.astype(np.uint8)
    return cv2.cvtColor(hsv_uint8, cv2.COLOR_HSV2RGB)

def adjust_vibrance(image, vibrance_scale=2.0):
    """
    Shift each pixel's HSV saturation away from the image's mean saturation.

    The shift is ``(1 - s / 255) * (s - mean) * vibrance_scale``: pixels above
    the mean gain saturation, pixels below it lose some, and the factor
    ``1 - s / 255`` shrinks the change for already saturated pixels. The result
    is clipped to [0, 255] and converted back to RGB.
    """
    hsv_float = cv2.cvtColor(image, cv2.COLOR_RGB2HSV).astype(np.float32)
    sat = hsv_float[..., 1]

    headroom = 1 - (sat / 255.0)
    offset_from_mean = sat - np.mean(sat)
    boost = headroom * offset_from_mean * vibrance_scale

    hsv_float[..., 1] = np.clip(sat + boost, 0, 255)
    return cv2.cvtColor(hsv_float.astype(np.uint8), cv2.COLOR_HSV2RGB)

def preprocess_image(image):
    """Boost saturation and then vibrance of an RGB image, each with a scale of 3.0."""
    saturated = adjust_saturation(image, saturation_scale=3.0)
    return adjust_vibrance(saturated, vibrance_scale=3.0)

def get_dominant_colors(image_path, k=10, top_n=2, min_distance=50, saturation_threshold=50):
    """Find the names and RGB values of the most dominant colors of an image.

    Args:
        image_path: Path of the image file; it is read with ``cv2.imread``.
        k: Number of K-means clusters (capped at the number of usable pixels).
        top_n: Number of distinct color names to return.
        min_distance: Minimum Euclidean RGB distance between two cluster
            centers for both to count as distinct.
        saturation_threshold: Pixels whose HSV saturation is not above this
            value are ignored when clustering.

    Returns:
        ``(names, colors, preprocessed_image)``: up to ``top_n`` unique color
        names, the matching cluster centers (RGB), and the saturation/vibrance
        boosted image that was clustered.

    Raises:
        FileNotFoundError: If the image cannot be read.
    """
    # Load image and convert to RGB (cv2.imread returns None on failure)
    image = cv2.imread(image_path)
    if image is None:
        raise FileNotFoundError(f"Could not read image: {image_path}")
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    # Preprocess the image
    preprocessed_image = preprocess_image(image)

    # Reshape the image to be a list of pixels
    all_pixels = preprocessed_image.reshape((-1, 3))

    # Convert to HSV to filter by saturation
    hsv_pixels = cv2.cvtColor(all_pixels.reshape(-1, 1, 3).astype(np.uint8), cv2.COLOR_RGB2HSV).reshape(-1, 3)
    pixels = all_pixels[hsv_pixels[:, 1] > saturation_threshold]
    # A grayscale or washed-out image can lose every pixel to the filter;
    # cluster the unfiltered pixels rather than crash on an empty array.
    if len(pixels) == 0:
        pixels = all_pixels

    # Perform K-means clustering to find the most common colors.
    # K-means cannot form more clusters than there are samples.
    n_clusters = max(1, min(k, len(pixels)))
    kmeans = KMeans(n_clusters=n_clusters, random_state=42)
    kmeans.fit(pixels)

    # Order the cluster centers from most to least populated
    counts = np.bincount(kmeans.labels_, minlength=n_clusters)
    dominant_indices = np.argsort(-counts)
    dominant_colors = kmeans.cluster_centers_[dominant_indices]

    # Keep the most dominant centers that are at least min_distance apart
    distinct_colors = []
    for color in dominant_colors:
        if all(
            cdist([color], [distinct_color], metric='euclidean')[0][0] > min_distance
            for distinct_color in distinct_colors
        ):
            distinct_colors.append(color)
        if len(distinct_colors) >= top_n:
            break

    # Convert colors to names, keeping each name once. The well-separated
    # centers are tried first; if they yield fewer than top_n unique names,
    # every cluster is tried in dominance order. The original fallback sliced
    # by len(distinct_colors), which skipped clusters that were never selected.
    distinct_color_names = []
    distinct_colors_filtered = []
    for candidates in (distinct_colors, dominant_colors):
        for color in candidates:
            if len(distinct_color_names) >= top_n:
                break
            name = closest_color(color)
            if name not in distinct_color_names:
                distinct_color_names.append(name)
                distinct_colors_filtered.append(color)

    return distinct_color_names[:top_n], distinct_colors_filtered[:top_n], preprocessed_image

# Path of the cropped detection to analyze
image_path = 'yolov9-main/runs/detect/00120_e/crops/backpack/00120_Scene-562.jpg'

# Get the dominant colors in the image, plus the boosted image that was clustered
dominant_color_names, dominant_colors, preprocessed_image = get_dominant_colors(image_path)
print("Dominant Color Names:", dominant_color_names)
print("Dominant Colors RGB:", dominant_colors)

# Display the original image and the image with increased vibrance and saturation.
# The file is read again here because get_dominant_colors only returns the
# preprocessed version; OpenCV loads BGR, so convert to RGB for matplotlib.
original_image = cv2.imread(image_path)
original_image = cv2.cvtColor(original_image, cv2.COLOR_BGR2RGB)

plt.figure(figsize=(12, 6))

# Left panel: untouched image
plt.subplot(1, 2, 1)
plt.imshow(original_image)
plt.title("Original Image")
plt.axis('off')

# Right panel: preprocessed image, titled with the detected color names
plt.subplot(1, 2, 2)
plt.imshow(preprocessed_image)
plt.title(f"Preprocessed Image\n(Dominant Colors: {dominant_color_names})")
plt.axis('off')

plt.show()


# OCR using EasyOCR

In [None]:
import cv2
import easyocr
import difflib
import json

def ocr_video(video_path, output_file, similarity_threshold=0.5, frame_step=10):
    """OCR every ``frame_step``-th frame of a video and group similar text into blocks.

    Consecutive sampled frames whose text is at least ``similarity_threshold``
    similar (``difflib.SequenceMatcher`` ratio) are merged into one block that
    keeps the text of its first frame. Frames without any text are skipped.

    Args:
        video_path: Path of the video file.
        output_file: Path of the JSON file the blocks are written to.
        similarity_threshold: Minimum ratio for two texts to share a block.
        frame_step: Sampling interval in frames.

    Returns:
        A list of dicts with keys ``'text'``, ``'start_time'`` and
        ``'end_time'`` (times in seconds).

    Raises:
        ValueError: If the video has frames but reports a non-positive frame rate.
    """
    # Initialize the video capture and OCR reader
    cap = cv2.VideoCapture(video_path)
    results = []

    try:
        reader = easyocr.Reader(['en'])

        frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        fps = cap.get(cv2.CAP_PROP_FPS)
        if frame_count > 0 and fps <= 0:
            raise ValueError(f"Video reports an invalid frame rate ({fps}): {video_path}")

        prev_text = ""
        current_block = None

        for i in range(0, frame_count, frame_step):
            # Seek first, then read. Reading before seeking returns the frame
            # of the previous iteration, so the text would be stamped one step late.
            cap.set(cv2.CAP_PROP_POS_FRAMES, i)
            ret, frame = cap.read()
            if not ret:
                break

            # Timestamp of the current frame and of the end of its sampling interval
            timestamp = i / fps
            interval_end = timestamp + (frame_step / fps)

            # Convert the frame to RGB (easyocr works on RGB images)
            frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)

            # Perform OCR on the frame
            ocr_result_list = reader.readtext(frame_rgb, detail=0)
            ocr_result = ' '.join(ocr_result_list).strip().lower()

            # Skip frames with no text
            if not ocr_result:
                continue

            # Calculate similarity with the previous text
            similarity = difflib.SequenceMatcher(None, prev_text, ocr_result).ratio()

            if similarity >= similarity_threshold and current_block:
                # Similar enough: extend the current block
                current_block['end_time'] = interval_end
            else:
                # Different text (or no open block): close the current block and start a new one
                if current_block:
                    results.append(current_block)
                current_block = {
                    'text': ocr_result,
                    'start_time': timestamp,
                    'end_time': interval_end
                }
            prev_text = ocr_result

        # Finalize the last block
        if current_block:
            results.append(current_block)
    finally:
        # Release the capture even when OCR or decoding fails midway
        cap.release()

    # Save results to a JSON file
    with open(output_file, 'w') as f:
        json.dump(results, f, indent=4)

    return results

# Example usage: OCR one preprocessed video and store the text blocks as JSON.
# The JSON file is what the string-search cell further down reads.
video_path = 'preprocessed_videos/00102/00102.mp4'
output_file = 'ocr_results.json'
similarity_threshold = 0.5  # minimum SequenceMatcher ratio for merging consecutive texts
frame_step = 10  # OCR every 10th frame
ocr_results = ocr_video(video_path, output_file, similarity_threshold, frame_step)

# Uncomment to print every detected text block with its time span
# for result in ocr_results:
#     print(f"Text: {result['text']}, Start Time: {result['start_time']:.2f}, End Time: {result['end_time']:.2f}")


# Search for a string in OCR results

In [None]:
import json
import difflib

def compare_string_with_ocr_results(ocr_file, input_string, similarity_threshold=0.8):
    """Return the OCR blocks whose text is similar to ``input_string``.

    Both the query and each block's text are stripped and lower-cased before
    being compared with ``difflib.SequenceMatcher``, so the match does not
    depend on letter case. The whole block text is compared with the whole
    query; this is not a substring search.

    Args:
        ocr_file: Path of a JSON file holding a list of dicts with the keys
            ``'text'``, ``'start_time'`` and ``'end_time'``.
        input_string: Text to look for.
        similarity_threshold: Minimum similarity ratio (0..1) for a match.

    Returns:
        A list of dicts (``'text'`` as stored in the file, ``'start_time'``,
        ``'end_time'``), in file order.
    """
    # Load OCR results from the file
    with open(ocr_file, 'r', encoding='utf-8') as f:
        ocr_results = json.load(f)

    # Normalize the input string
    input_string = input_string.strip().lower()

    matching_results = []

    for result in ocr_results:
        ocr_text = result['text']
        # Normalize the stored text the same way as the query; the original
        # text is still what gets reported.
        normalized_text = ocr_text.strip().lower()
        similarity = difflib.SequenceMatcher(None, normalized_text, input_string).ratio()

        if similarity >= similarity_threshold:
            matching_results.append({
                'text': ocr_text,
                'start_time': result['start_time'],
                'end_time': result['end_time']
            })

    return matching_results

# Example usage: look for a phrase in the OCR blocks written by the OCR cell
ocr_file = 'ocr_results.json'
input_string = 'Get reliable diving gear'
similarity_threshold = 0.6  # passed explicitly, overriding the function default of 0.8

matching_results = compare_string_with_ocr_results(ocr_file, input_string, similarity_threshold)

# Report every matching block with its time span in seconds
for result in matching_results:
    print(f"Matching Text: {result['text']}, Start Time: {result['start_time']:.2f}, End Time: {result['end_time']:.2f}")


# Using CLIP

In [None]:
import os
import torch
from PIL import Image
from transformers import CLIPProcessor, CLIPModel
import torchvision.transforms as T
from torchvision.models.detection import fasterrcnn_resnet50_fpn
from sklearn.metrics import confusion_matrix, recall_score
import matplotlib.pyplot as plt
import seaborn as sns
import xml.etree.ElementTree as ET

# Load models and processors: CLIP labels the crops, Faster R-CNN proposes the boxes
device = "cuda" if torch.cuda.is_available() else "cpu"
clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)
clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
# NOTE(review): newer torchvision releases replace ``pretrained=True`` with a ``weights=`` argument - confirm against the installed version
faster_rcnn = fasterrcnn_resnet50_fpn(pretrained=True).to(device)
faster_rcnn.eval()  # inference mode

# Define the folder containing images (annotations are looked up in the same
# folder, as .xml files named after each .jpg)
frame_dir = '00110'
annotation_dir = '00110'

# Define the queries for evaluation. queries[i] is the CLIP text prompt for
# query_labels[i]; the two lists must stay aligned index by index.
queries = ["A photo of a person", "A photo of a bird", "A photo of a truck", "A photo of a horse", "A photo of a car"]
query_labels = ["person", "bird", "truck", "horse", "car"]

# Mapping to unify bus and train as truck (applied to ground-truth labels)
class_mapping = {
    "bus": "truck",
    "train": "truck"
}

def get_ground_truth_labels(xml_file):
    """Read the annotated objects of interest from an XML annotation file.

    Every ``<object>`` element is considered. Its ``<name>`` is first remapped
    through ``class_mapping``; objects whose resulting label is not in
    ``query_labels`` are dropped.

    Returns:
        ``(labels, boxes)``: two parallel lists, where each box is
        ``[xmin, ymin, xmax, ymax]`` with integer coordinates taken from the
        object's ``<bndbox>``.
    """
    labels, boxes = [], []
    for obj in ET.parse(xml_file).getroot().findall('object'):
        raw_label = obj.find('name').text
        label = class_mapping.get(raw_label, raw_label)
        if label not in query_labels:
            continue
        labels.append(label)
        bndbox = obj.find('bndbox')
        boxes.append([int(bndbox.find(tag).text) for tag in ('xmin', 'ymin', 'xmax', 'ymax')])
    return labels, boxes

def iou(box1, box2):
    """Calculate Intersection Over Union (IOU) of two bounding boxes.

    Args:
        box1: ``(x1, y1, x2, y2)`` of the first box.
        box2: ``(x1, y1, x2, y2)`` of the second box.

    Returns:
        Intersection area divided by union area, or ``0.0`` when the union has
        no area (e.g. two degenerate zero-size boxes).
    """
    x1, y1, x2, y2 = box1
    x1_p, y1_p, x2_p, y2_p = box2

    # Overlap extent along each axis; clamped at 0 when the boxes are disjoint
    inter_width = max(0, min(x2, x2_p) - max(x1, x1_p))
    inter_height = max(0, min(y2, y2_p) - max(y1, y1_p))
    inter_area = inter_width * inter_height

    box1_area = (x2 - x1) * (y2 - y1)
    box2_area = (x2_p - x1_p) * (y2_p - y1_p)
    union_area = box1_area + box2_area - inter_area

    # Degenerate boxes give an empty union; report "no overlap" instead of dividing by zero
    if union_area <= 0:
        return 0.0
    return inter_area / union_area

# Parallel lists: entry i of both describes the same (ground truth, prediction) pairing
true_labels = []
predicted_labels = []

# The tensor conversion does not depend on the frame, so build it once
transform = T.Compose([T.ToTensor()])

for frame_name in os.listdir(frame_dir):
    if not frame_name.endswith(".jpg"):
        continue

    frame_path = os.path.join(frame_dir, frame_name)
    annotation_path = os.path.join(annotation_dir, frame_name.replace(".jpg", ".xml"))

    # PIL raises on an unreadable file instead of returning None, so the
    # "skip this frame" path has to be an exception handler.
    try:
        img = Image.open(frame_path).convert("RGB")
    except OSError:
        print(f"Error reading image: {frame_path}")
        continue

    # Get ground truth labels and boxes
    if not os.path.exists(annotation_path):
        print(f"Annotation file not found: {annotation_path}")
        continue
    gt_labels, gt_boxes = get_ground_truth_labels(annotation_path)

    # Convert image to tensor
    img_tensor = transform(img).to(device)

    # Generate bounding boxes using Faster R-CNN
    with torch.no_grad():
        predictions = faster_rcnn([img_tensor])
    pred_boxes = predictions[0]['boxes'].cpu().numpy()
    pred_scores = predictions[0]['scores'].cpu().numpy()

    # (label, CLIP probability, box) for every confident detection
    detected_labels = []

    for box, score in zip(pred_boxes, pred_scores):
        if score > 0.5:  # Adjust threshold as needed
            xmin, ymin, xmax, ymax = box
            cropped_img = img.crop((xmin, ymin, xmax, ymax))
            inputs = clip_processor(text=queries, images=cropped_img, return_tensors="pt", padding=True).to(device)
            # Inference only: skip autograd bookkeeping for the CLIP forward pass
            with torch.no_grad():
                outputs = clip_model(**inputs)
            logits_per_image = outputs.logits_per_image.softmax(dim=1).detach().cpu().numpy().flatten()

            # The query with the highest probability labels the crop
            best_idx = logits_per_image.argmax()
            detected_label = query_labels[best_idx]
            confidence = logits_per_image[best_idx]

            detected_labels.append((detected_label, confidence, [xmin, ymin, xmax, ymax]))

    # For visualization and evaluation
    frame_true_labels = []
    frame_predicted_labels = []
    matched_predictions = [False] * len(detected_labels)

    # Greedily pair every ground-truth box with the first unused detection
    # that overlaps it by more than 0.5 IoU; unmatched ground truth is a miss ("none").
    for label, gt_box in zip(gt_labels, gt_boxes):
        frame_true_labels.append(label)
        matched = False
        for i, (detected_label, _, detected_box) in enumerate(detected_labels):
            if iou(gt_box, detected_box) > 0.5 and not matched_predictions[i]:
                frame_predicted_labels.append(detected_label)
                matched_predictions[i] = True
                matched = True
                break
        if not matched:
            frame_predicted_labels.append("none")

    # Detections without a ground-truth partner count as false positives
    for i, (detected_label, _, detected_box) in enumerate(detected_labels):
        if not matched_predictions[i]:
            frame_true_labels.append("none")
            frame_predicted_labels.append(detected_label)

    true_labels.extend(frame_true_labels)
    predicted_labels.extend(frame_predicted_labels)

# Print the true and predicted labels for debugging
print("True Labels:", true_labels)
print("Predicted Labels:", predicted_labels)

# Ensure the lengths are equal
# NOTE(review): both lists appear to be extended in lockstep by the evaluation
# loop, so this truncation is presumably a no-op safeguard - confirm.
min_length = min(len(true_labels), len(predicted_labels))
filtered_true_labels = true_labels[:min_length]
filtered_predicted_labels = predicted_labels[:min_length]

# Calculate confusion matrix; "none" is added as an extra class for misses and
# unmatched detections (rows = true labels, columns = predicted labels)
conf_matrix = confusion_matrix(filtered_true_labels, filtered_predicted_labels, labels=query_labels + ["none"])
print("Confusion Matrix:")
print(conf_matrix)

# Calculate recall for each class (the "none" pseudo-class is left out of labels)
recall = recall_score(filtered_true_labels, filtered_predicted_labels, average=None, labels=query_labels)
print("Recall for each class:")
print(recall)

# Weighted average recall over the query classes
weighted_recall = recall_score(filtered_true_labels, filtered_predicted_labels, average='weighted', labels=query_labels)
print("Weighted Recall:")
print(weighted_recall)

# Plotting the confusion matrix as an annotated heatmap
plt.figure(figsize=(10, 7))
sns.heatmap(conf_matrix, annot=True, fmt="d", xticklabels=query_labels + ["none"], yticklabels=query_labels + ["none"], cmap="Blues")
plt.xlabel('Predicted Labels')
plt.ylabel('True Labels')
plt.title('Confusion Matrix')
plt.show()


# Using CLIP and BLIP

In [None]:
import os
import torch
from PIL import Image
from transformers import CLIPProcessor, CLIPModel, BlipProcessor, BlipForConditionalGeneration
import torchvision.transforms as T
from torchvision.models.detection import FasterRCNN, fasterrcnn_resnet50_fpn
import matplotlib.pyplot as plt
from pymongo import MongoClient, errors
import datetime

# Define the folder containing images
folder_path = "keyframes/00102"

# Define the paths to the weights (local files, so no download is attempted)
fasterrcnn_weights_path = "weights/fasterrcnn_resnet50_fpn_coco-258fb6c6.pth"
resnet50_weights_path = "weights/resnet50-0676ba61.pth"

# Check if the files exist
# NOTE(review): assert statements are stripped when Python runs with -O, so
# these checks would silently disappear in that mode.
assert os.path.exists(fasterrcnn_weights_path), "Faster R-CNN weights file not found!"
assert os.path.exists(resnet50_weights_path), "ResNet50 weights file not found!"

# Print paths to verify
print(f"Faster R-CNN weights path: {fasterrcnn_weights_path}")
print(f"ResNet50 weights path: {resnet50_weights_path}")

# Load the ResNet50 backbone with local weights
from torchvision.models import resnet50
backbone = resnet50(pretrained=False)
backbone_state_dict = torch.load(resnet50_weights_path, map_location=torch.device('cpu'))

# Remove the fully connected layer weights from the state dictionary
backbone_state_dict.pop("fc.weight", None)
backbone_state_dict.pop("fc.bias", None)

# Load the state dictionary with strict=False to ignore missing keys
backbone.load_state_dict(backbone_state_dict, strict=False)

# Create a custom backbone with FPN from the loaded ResNet50 backbone
from torchvision.models.detection.backbone_utils import resnet_fpn_backbone

# Use the backbone with FPN, ensuring it uses the locally loaded weights
backbone_with_fpn = resnet_fpn_backbone('resnet50', pretrained=False, norm_layer=torch.nn.BatchNorm2d)
backbone_with_fpn.body.load_state_dict(backbone.state_dict(), strict=False)

# Load the Faster R-CNN model with the custom backbone
# NOTE(review): the full detector state dict is loaded (strictly) right after
# construction, so it presumably overwrites the backbone weights copied above - confirm.
detection_model = FasterRCNN(backbone=backbone_with_fpn, num_classes=91)  # Use the backbone explicitly
detection_model.load_state_dict(torch.load(fasterrcnn_weights_path, map_location=torch.device('cpu')))
detection_model.eval()

# Load the CLIP model and processor (no device is selected in this cell)
clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

# Load the BLIP captioning model and processor
blip_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
blip_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")

# MongoDB setup with error handling: fail fast (5 s timeout) when no server is
# reachable, since every detection below is written to the collection
try:
    client = MongoClient('mongodb://localhost:27017/', serverSelectionTimeoutMS=5000)
    client.server_info()  # Trigger exception if cannot connect to db
    db = client['object_detection']
    collection = db['detected_objects']
except errors.ServerSelectionTimeoutError as err:
    print("Failed to connect to MongoDB server:", err)
    exit(1)

# Transform for the object detection model (PIL image -> tensor)
transform = T.Compose([T.ToTensor()])

# Caption a PIL image (or crop) with the BLIP model and processor loaded above
def generate_caption(image):
    """Return the BLIP caption generated for ``image``, with special tokens removed."""
    blip_inputs = blip_processor(images=image, return_tensors="pt")
    token_ids = blip_model.generate(**blip_inputs)
    return blip_processor.decode(token_ids[0], skip_special_tokens=True)

# Process each image in the folder: detect objects, caption every crop with
# BLIP, score the caption against the crop with CLIP, store the result in
# MongoDB and finally plot the boxes.
threshold = 0.5  # minimum detector score for a box to be kept

for filename in os.listdir(folder_path):
    if not filename.lower().endswith(('.png', '.jpg', '.jpeg')):
        continue

    image_path = os.path.join(folder_path, filename)
    # Force three channels: PNGs with alpha or grayscale files would otherwise
    # produce tensors the detector cannot normalize.
    image = Image.open(image_path).convert("RGB")

    # Transform image for the detection model
    image_tensor = transform(image)

    # Get bounding boxes
    with torch.no_grad():
        detections = detection_model([image_tensor])[0]

    # Filter out low-confidence detections
    boxes = [box for box, score in zip(detections['boxes'], detections['scores']) if score > threshold]

    detected_objects = []

    # Use BLIP to generate captions for objects within bounding boxes
    for box in boxes:
        # Plain Python ints are needed for PIL, matplotlib and BSON alike
        xmin, ymin, xmax, ymax = box.int().tolist()
        # Truncating to integers can collapse a thin box to zero width/height
        if xmax <= xmin or ymax <= ymin:
            continue

        cropped_image = image.crop((xmin, ymin, xmax, ymax))
        caption = generate_caption(cropped_image)
        inputs = clip_processor(text=[caption], images=cropped_image, return_tensors="pt", padding=True)
        with torch.no_grad():
            outputs = clip_model(**inputs)
        # With a single text candidate, softmax over logits_per_image is always
        # 1.0, so it carries no information. Use the image/text cosine
        # similarity as the agreement score instead.
        confidence = float(
            torch.nn.functional.cosine_similarity(outputs.image_embeds, outputs.text_embeds)[0]
        )
        detected_label = caption  # Use the generated caption as the label

        detected_objects.append({
            "box": [xmin, ymin, xmax, ymax],
            "label": detected_label,
            "confidence": confidence
        })

        # Prepare the data to be stored in MongoDB
        detected_object = {
            "filename": filename,
            "label": detected_label,
            "confidence": confidence,
            "box": [xmin, ymin, xmax, ymax],
            # Timezone-aware UTC timestamp (datetime.utcnow() is naive and deprecated)
            "timestamp": datetime.datetime.now(datetime.timezone.utc)
        }

        # Insert the data into MongoDB
        collection.insert_one(detected_object)
        print(f"Image: {filename}, Detected {detected_label} with confidence {confidence:.4f} within box {box}")

    # Optionally, display the image with detected bounding boxes and labels
    plt.imshow(image)
    plt.axis('off')
    ax = plt.gca()
    for obj in detected_objects:
        xmin, ymin, xmax, ymax = obj['box']
        detected_label = obj['label']
        confidence = obj['confidence']
        rect = plt.Rectangle((xmin, ymin), xmax - xmin, ymax - ymin, fill=False, color='red', linewidth=2)
        ax.add_patch(rect)
        plt.text(xmin, ymin, f'{detected_label} {confidence:.2f}', bbox=dict(facecolor='yellow', alpha=0.5))

    plt.show()
