In [1]:
import os
import time
import pandas as pd
import numpy as np
import openai
from tqdm import tqdm
tqdm.pandas()

In [2]:
openai.api_key = os.getenv("OPENAI_API_KEY", None)

In [3]:
df_tagged_imgs = pd.read_csv("./tagged_dataset/final-tags.csv")
df_tagged_imgs = df_tagged_imgs.loc[:, ~df_tagged_imgs.columns.str.contains('^Unnamed')]
df_tagged_imgs.tail(2)

Unnamed: 0,image_name,image_path,tags,extended_tags,background_category
1308,280678d5-383f-450e-856e-08e2c11288e7.png,/media/pixis/pixis/Ravi_workspace/Entity-taggi...,"['circle', 'cotton candy', 'cube', 'green', 'h...","['minimalistic background', 'top view of stone...",MINIMALISTIC
1309,891a97af-592f-4837-bd62-693793401ff7.png,/media/pixis/pixis/Ravi_workspace/Entity-taggi...,"['crack', 'marble', 'stone', 'white']","['minimalistic background', 'kept on old shabb...",MINIMALISTIC


In [4]:
import ast
def get_tags_embeddings(tags_list):
    tags_list = ast.literal_eval(tags_list)
    tag_line = ', '.join(tags_list)
    embeddings = []
    response = openai.embeddings.create(input=tag_line, model="text-embedding-3-large")
    embeddings = response.data[0].embedding
    return embeddings

start_time = time.time()
df_tagged_imgs["tag_embeddings"] = df_tagged_imgs["extended_tags"].progress_apply(get_tags_embeddings)
end_time = time.time()
execution_time = end_time - start_time
# print(execution_time)

100%|██████████| 1310/1310 [17:07<00:00,  1.28it/s]  


In [5]:
df_tagged_imgs.to_csv("./tagged_dataset/embedded_tags.csv")

In [25]:
from PIL import Image
import numpy as np


def get_numpy_array(image_path):
    img = Image.open(image_path)
    return img

In [26]:
df_tagged_imgs["numpy_arr_img"] = df_tagged_imgs["image_path"].progress_apply(get_numpy_array)

100%|██████████| 1310/1310 [00:18<00:00, 71.74it/s] 


In [5]:
import matplotlib.pyplot as plt
from sklearn.metrics.pairwise import cosine_similarity

def get_top_n_plots(top_image_arrays):
    images_per_row = 3
    total_images = len(top_image_arrays)
    rows = (total_images + images_per_row - 1) // images_per_row

    # Create a figure with the specified size
    fig, axes = plt.subplots(nrows=rows, ncols=images_per_row, figsize=(12, rows * 3))

    # Loop through the top_image_arrays and plot them
    for idx, img in enumerate(top_image_arrays):
        row = idx // images_per_row
        col = idx % images_per_row
        axes[row, col].imshow(img)
        axes[row, col].axis('off')

    # Hide any empty subplots
    for i in range(total_images, rows * images_per_row):
        row = i // images_per_row
        col = i % images_per_row
        axes[row, col].axis('off')

    # Adjust layout and display the plot
    plt.tight_layout()
    plt.show()


def get_top_n_similar_embeddings(new_embedding, tags_df, top_n=10):
    # Ensure the embeddings are in the correct format
    tag_embeddings = np.array(tags_df['tag_embeddings'].tolist())
    
    # Calculate the cosine similarity between the new embedding and the embeddings in the dataframe
    similarities = cosine_similarity([new_embedding], tag_embeddings)[0]
    
    top_indices = np.argsort(similarities)[-top_n:][::-1]
    
    top_similarities = similarities[top_indices]
    top_embeddings = tags_df.iloc[top_indices]
    
    # Get the image file paths of the top_n most similar embeddings
    top_image_arrays = top_embeddings['numpy_arr_img'].tolist()
    
    return top_image_arrays, top_similarities

def embed_query(input_query):
    input_query = ', '.join(input_query)
    response = openai.embeddings.create(input=input_query, model="text-embedding-3-large")
    embeddings = response.data[0].embedding
    return embeddings

In [13]:
df_tagged_imgs = df_tagged_imgs.loc[:, ~df_tagged_imgs.columns.str.contains('^Unnamed')]

In [6]:
query = ["sea"]
curr_embbedding = embed_query(query)
print(curr_embbedding)
# top_image_arrays, top_similarities = get_top_n_similar_embeddings(curr_embbedding, df_tagged_imgs, top_n=9)
# get_top_n_plots(top_image_arrays)

[-0.024967990815639496, 0.000279679661616683, 0.021536072716116905, 2.3291264369618148e-05, 0.000540172855835408, 0.019725656136870384, 0.03787703812122345, 0.012192754074931145, -0.06819756329059601, 0.023220546543598175, -0.014546294696629047, 0.04385928064584732, 0.006722151301801205, -0.012531223706901073, -0.01176770031452179, 0.005281690042465925, -0.018182868137955666, -0.010642094537615776, -0.022307466715574265, -0.01185428537428379, 0.0058995927684009075, -0.0033728827256709337, 0.012137655168771744, 0.014774564653635025, 0.0035696669947355986, 0.020733192563056946, 0.03302827477455139, 0.027612771838903427, -0.022590836510062218, 0.008461724035441875, 0.01697067730128765, 0.0678827092051506, -0.010287882760167122, -0.0028927288949489594, -0.03727881610393524, 0.04294620454311371, -0.007918599992990494, -0.010886106640100479, 0.006013727746903896, -0.040899645537137985, -0.006777250673621893, -0.008217711932957172, -0.0104846665635705, 0.04436304792761803, -0.0148532781749963

In [27]:
df_tagged_imgs.to_csv("./tagged_dataset/pil_embedded_tags.csv")

In [7]:
df_tagged_imgs = pd.read_csv("./tagged_dataset/pil_embedded_tags.csv")

In [9]:
print(type(df_tagged_imgs["tag_embeddings"]))
temp = df_tagged_imgs["tag_embeddings"].iloc[0]
print(type(temp))

<class 'pandas.core.series.Series'>
<class 'str'>
