In [None]:
# ========================
# Import Libraries
# ========================
import os
import time 
import gc
import pickle
from pathlib import Path
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np
from IPython.display import display, Audio, Markdown
import torch
from transformers import OwlViTProcessor, OwlViTForObjectDetection
from transformers.image_utils import ImageFeatureExtractionMixin
from transformers.utils import send_example_telemetry
from tensorflow.keras.preprocessing import image
from torchvision.datasets import CIFAR100
from openai import OpenAI 
from dotenv import load_dotenv
import base64
import openai

In [None]:
# ========================
# Model Initialization
# ========================

# Telemetry call from the transformers example-notebook utilities.
send_example_telemetry("zeroshot_object_detection_with_owlvit_notebook", framework="pytorch")

# Detector, its matching processor, and the image helper mixin used for resizing.
model = OwlViTForObjectDetection.from_pretrained("google/owlvit-base-patch32")
processor = OwlViTProcessor.from_pretrained("google/owlvit-base-patch32")
mixin = ImageFeatureExtractionMixin()

# Device configuration: use CUDA when torch reports it, otherwise the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(
    "GPU is available. Using GPU."
    if device.type == "cuda"
    else "GPU is not available. Using CPU."
)

# ========================
# OpenAI API Setup
# ========================

MODEL="gpt-4o"

# Read a local .env file (if one exists) into the process environment so that
# OPENAI_API_KEY can be kept out of the notebook. Variables that are already
# set in the environment are left as they are.
load_dotenv()

# Explicit client used for the chat-completion (image description) requests.
# The placeholder is only a fallback for when OPENAI_API_KEY is not set.
client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY", "YOUR_API_KEY_HERE"))

# get_embeddings() goes through the module-level `openai` client, so it must
# be given the same key; previously it was hard-coded to the placeholder and
# ignored OPENAI_API_KEY.
openai.api_key = client.api_key

def get_embeddings(text):
    """Return the embedding of ``text`` from the "text-embedding-ada-002" model.

    One request is sent through the module-level ``openai`` client; the value
    returned is the ``embedding`` field of the first entry in the response's
    ``data`` list.
    """
    return openai.embeddings.create(
        input=text,
        model="text-embedding-ada-002",  # Use the appropriate model for embeddings
    ).data[0].embedding

In [None]:
# ========================
# Visualization Function
# ========================

def plot_predictions(input_image, text_queries, scores, boxes, labels, num, act, score_threshold, show=True):
    """Draw the leading detections over the image as red boxes with captions.

    Args:
        input_image: image passed to ``imshow``; it is stretched over the unit
            square (x from 0 to 1, y from 0 at the top to 1 at the bottom), and
            boxes are drawn in those same axes coordinates.
        text_queries: text printed in front of the score in every caption.
        scores, boxes, labels: parallel sequences, visited in the given order.
            Each box is unpacked as ``(cx, cy, w, h)``. ``labels`` is iterated
            but not used for drawing.
        num: only the first ``num`` entries can be drawn.
        act: when equal to 1, each drawn box is replaced by the fixed box
            ``[0.5, 0.5, 1.0, 1.0]`` (the whole frame).
        score_threshold: entries scoring below this value are not drawn.
        show: when true, ``plt.show()`` is called at the end.

    Returns:
        None. The figure is not closed here; that is left to the caller.
    """
    _, axes = plt.subplots(1, 1, figsize=(3, 3))
    axes.imshow(input_image, extent=(0, 1, 1, 0))
    axes.set_axis_off()

    caption_style = {
        "facecolor": "white",
        "edgecolor": "red",
        "boxstyle": "square,pad=.3",
    }

    for position, (score, box, _label) in enumerate(zip(scores, boxes, labels), start=1):
        # Entries below the threshold or beyond the first `num` are skipped.
        if score < score_threshold or position > num:
            continue

        cx, cy, w, h = [0.5, 0.5, 1.0, 1.0] if act == 1 else box
        left, right = cx - w / 2, cx + w / 2
        top, bottom = cy - h / 2, cy + h / 2

        # Closed rectangle outline: five points, ending where it started.
        axes.plot([left, right, right, left, left],
                  [top, top, bottom, bottom, top], "r")
        # Caption sits just under the bottom-left corner of the box.
        axes.text(
            left,
            bottom + 0.015,
            f"{text_queries}: {score:1.5f}",
            ha="left",
            va="top",
            color="red",
            bbox=dict(caption_style),
        )

    if show:
        plt.show()

In [None]:
# ========================
# Image Description Generation Loop
# ========================

plot_interval = 100  # Plot every 100 images for visualization
score_threshold = 0.001
max_attempts = 3     # API attempts per image before the image is skipped
retry_delay_s = 5    # Pause between two attempts, in seconds


def encode_image(image_path):
    """Return the bytes of the file at ``image_path`` as a base64 text string."""
    with open(image_path, "rb") as image_file:
        return base64.b64encode(image_file.read()).decode("utf-8")


def describe_image(image_path):
    """Send one floor image to the chat model ``MODEL`` and return the reply text.

    The prompt asks for a reply of the form
    "item on the floor, avoid/suck, number of item, floor description".

    NOTE: the prompt literal uses backslash line continuations *inside* the
    string, so the leading whitespace of each continuation line is part of the
    text sent to the model. Those lines are kept at their original columns on
    purpose; re-indenting them would change the prompt.
    """
    base64_image = encode_image(image_path)
    response = client.chat.completions.create(
        model=MODEL,
        messages=[
            {"role": "system", "content": "You are a helpful assistant that describes the images for a robotic vacuum cleaner."},
            {"role": "user", "content": [
                {"type": "text", "text": "You are an AI system that helps a smart robotic vacuum cleaner. This is an image of the floor captured by the camera of the vacuum cleaner robot.\
                        What does the floor look like? in terms of color and pattern. (gray carpet, wooden floor, zigzag carpet, etc.) Keep your answer short.\
                        What do you see on the floor? sprinkles, crumbs, pet feces, nail, chair legs, USB drive, socks\
                        If there is nothing on the floor,just say NONE.\
                        How many of this item is in this image? if it is NONE or uncountable (like crumbs) say 1.\
                        Should the vacuum cleaner suck this item or avoid it? Do not give explanations.\
                        output your answer in this format: item on the floor, avoid/suck, number of item, floor description\
                        examples: \
                        sprinkles, suck, 1, zigzag carpet\
                        NONE, suck, 1, zigzag carpet\
                        crumbs, suck, 1, gray carpet\
                        nail, avoid, 1, wooden floor\
                        paper clip, avoid, 2, wooden floor\
                        NONE, suck, 1, gray carpet\
                        chair leg, avoid, 1, gray carpet\
                    "},
                {"type": "image_url", "image_url": {
                    "url": f"data:image/png;base64,{base64_image}"}
                }
            ]}
        ],
        temperature=0.0,
    )
    return response.choices[0].message.content


def parse_description(description):
    """Split "item, action, count, floor" into ``(item, action, count, floor)``.

    Fields are stripped of surrounding whitespace. A missing field is returned
    as "". The count is returned as an int, or as "" when it is missing or is
    not an integer (instead of raising and aborting the whole run).
    """
    fields = [field.strip() for field in description.split(",")]
    fields += [""] * (4 - len(fields))
    obj, action, count, floor = fields[:4]
    try:
        count = int(count)
    except ValueError:
        count = ""
    return obj, action, count, floor


# Loop-invariant: move the detector to the target device once, in eval mode.
model = model.to(device)
model.eval()

for day in ['day1', 'day2', 'day3', 'day4', 'day5', 'day6', 'day7', 'day8', 'day9']:
    print('Generating language descriptions for', day, '...')
    print('Querying GPT-4o for image descriptions...')

    # Set the data path
    rel_data_path = f'data/{day}/images/'
    dir_path = os.path.join(Path.cwd().parent, rel_data_path)
    # Built from the day directly rather than by replacing "images" in dir_path,
    # which would also rewrite any parent directory whose name contains "images".
    labels_dir = os.path.join(Path.cwd().parent, f'data/{day}/labels/')
    imgs_path = sorted(os.listdir(dir_path))
    descriptions = []

    # Load existing dictionaries or initialize empty ones.
    # NOTE(security): pickle.load can run arbitrary code from the file; only
    # load pickles that this notebook wrote itself.
    try:
        pickle_file_path = 'image_descriptions_dict.pickle'
        with open(pickle_file_path, 'rb') as f:
            image_descriptions_dict = pickle.load(f)

        pickle_file_path = 'text_embeddings_dict.pickle'
        with open(pickle_file_path, 'rb') as f:
            text_embeddings_dict = pickle.load(f)

    except FileNotFoundError:
        print("Dictionaries not found. Initializing empty dictionaries.")
        text_embeddings_dict = {}
        image_descriptions_dict = {}

    # ------------------------
    # Describe every image of the day and embed each new description
    # ------------------------
    for img_name in imgs_path:
        image_key = f"{rel_data_path}{img_name}"
        IMAGE_PATH = os.path.join(dir_path, img_name)

        des = None
        for attempt in range(1, max_attempts + 1):
            try:
                # The description is requested only until it has been obtained,
                # so a failure in the embedding step does not repeat the chat call.
                if des is None:
                    des = describe_image(IMAGE_PATH)
                    descriptions.append(des)

                # Store results
                item = des.lower()
                image_descriptions_dict[image_key] = item

                # Generate embeddings for new items
                if item not in text_embeddings_dict:
                    text_embeddings_dict[item] = get_embeddings(item)
                break

            except Exception as e:
                print(f"An error occurred: {e}")
                if attempt == max_attempts:
                    print(f"Giving up on {img_name} after {max_attempts} attempts.")
                    break
                print(f"Retrying in {retry_delay_s} seconds...")
                time.sleep(retry_delay_s)

    # Save generated dictionaries
    pickle_file_path = 'image_descriptions_dict.pickle'
    with open(pickle_file_path, 'wb') as f:
        pickle.dump(image_descriptions_dict, f)

    pickle_file_path = 'text_embeddings_dict.pickle'
    with open(pickle_file_path, 'wb') as f:
        pickle.dump(text_embeddings_dict, f)

    # Create image-to-embedding mapping
    image_embeddings_dict = {}
    for img_path, description in image_descriptions_dict.items():
        if description in text_embeddings_dict:
            image_embeddings_dict[img_path] = text_embeddings_dict[description]
        else:
            print(f"⚠️  Warning: No embedding found for description: {description}")

    print('Plotting sample images for visualization ... \n')

    # Split every stored description into its four fields, keyed by image path.
    dict_obj = {}
    dict_act = {}
    dict_floor = {}
    dict_num = {}
    for key, value in image_descriptions_dict.items():
        dict_obj[key], dict_act[key], dict_num[key], dict_floor[key] = parse_description(value)

    os.makedirs(labels_dir, exist_ok=True)

    for image_index, img_name in enumerate(imgs_path):
        image_key = f"{rel_data_path}{img_name}"

        # Images whose description is missing (API failure) or has no integer
        # count cannot be labelled; skip them instead of aborting the run.
        num_objects = dict_num.get(image_key, "")
        if not isinstance(num_objects, int):
            print(f"⚠️  Warning: No usable description for {image_key}; skipping.")
            continue

        show = image_index % plot_interval == 0

        # Bind the per-image names first so the cleanup in `finally` cannot
        # raise NameError (and hide the real error) when a step fails early.
        inputs = outputs = image1 = image = image_resized = None
        try:
            # ------------------------
            # Load and preprocess image
            # ------------------------
            img_path = os.path.join(dir_path, img_name)
            image1 = Image.open(img_path)
            image = Image.fromarray(np.uint8(image1)).convert("RGB")

            text_queries = dict_obj[image_key]

            # Prepare inputs
            inputs = processor(text=text_queries, images=image, return_tensors="pt").to(device)

            # ------------------------
            # Model inference
            # ------------------------
            with torch.no_grad():
                outputs = model(**inputs)

            # ------------------------
            # Post-processing
            # ------------------------
            logits = torch.max(outputs["logits"][0], dim=-1)
            scores = torch.sigmoid(logits.values).cpu().detach().numpy()
            labels = logits.indices.cpu().detach().numpy()
            boxes = outputs["pred_boxes"][0].cpu().detach().numpy()

            # Sort predictions by score, highest first
            sorted_data = sorted(zip(scores, boxes, labels), key=lambda x: x[0], reverse=True)
            sorted_scores = [d[0] for d in sorted_data]
            sorted_boxes = [d[1] for d in sorted_data]
            sorted_labels = [d[2] for d in sorted_data]

            # Label class: 0 when the description says "avoid", otherwise 1.
            act = 0 if 'avoid' in dict_act[image_key] else 1

            # Highest-scoring boxes, one per reported object. max() guards
            # against a negative count turning the slice into "almost all".
            top_boxes = sorted_boxes[:max(num_objects, 0)]

            # ------------------------
            # Plot & Save Labels
            # ------------------------
            # The figure is only built for images that are actually displayed.
            if show:
                print('IMAGE PATH:', rel_data_path + img_name)
                print('Language description:', image_descriptions_dict[image_key])

                image_size = model.config.vision_config.image_size
                image_resized = mixin.resize(image, image_size)
                input_image = np.asarray(image_resized).astype(np.float32) / 255.0

                plot_predictions(input_image, text_queries, sorted_scores, sorted_boxes, sorted_labels, num_objects, act=act, score_threshold=score_threshold, show=show)
                plt.close('all')  # prevent figure accumulation

                for box in top_boxes:
                    print("Predicted bounding box:", [round(coord, 5) for coord in box])
                print('------------------------------------------')

            # Save labels to file: one "<class> <cx> <cy> <w> <h>" line per object.
            # splitext handles any image extension, not only ".jpg".
            label_filename = os.path.join(labels_dir, os.path.splitext(img_name)[0] + '.txt')
            with open(label_filename, 'w') as f:
                for box in top_boxes:
                    if act == 1:
                        # Non-"avoid" items are labelled with the whole frame.
                        line = '1.0 0.5 0.5 1.0 1.0\n'
                    else:
                        line = f"{act} {box[0]} {box[1]} {box[2]} {box[3]}\n"
                    f.write(line)

        finally:
            # ------------------------
            # Cleanup to prevent freezing
            # ------------------------
            del inputs, outputs, image1, image, image_resized
            gc.collect()
            torch.cuda.empty_cache()
    



In [None]:
# Map each image path to the embedding stored for its description text;
# descriptions without a stored embedding are reported and left out.
image_embeddings_dict = {}
for img_path, description in image_descriptions_dict.items():
    if description not in text_embeddings_dict:
        print(f"⚠️  Warning: No embedding found for description: {description}")
        continue
    image_embeddings_dict[img_path] = text_embeddings_dict[description]

