## Import libraries

In [1]:
from sentence_transformers import SentenceTransformer
import torch
import json
import chromadb
import os
import base64
from io import BytesIO
import torchvision.transforms as T
from PIL import Image

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
from src.engine.text_embedding import TextEmbeddingGenerator
from src.engine.image_embedding import ImageEmbeddingGenerator

## Initialize the database

In [3]:
# Select the compute device: CUDA when torch reports it as available, CPU otherwise.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [4]:
# Instantiate the project's embedding generators (imported from src.engine);
# the indexing cells call generate_text_embedding / generate_image_embedding on them.
# NOTE(review): the recorded output of this cell mentions a torch.hub cache for
# facebookresearch/dinov2, so construction presumably loads a DINOv2 model —
# confirm in src.engine before relying on it.
text_embedding = TextEmbeddingGenerator()
image_embedding = ImageEmbeddingGenerator()

Using cache found in C:\Users\Lenovo/.cache\torch\hub\facebookresearch_dinov2_main


In [5]:
# Persistent Chroma client rooted at ./chromadb (relative to the working directory).
client = chromadb.PersistentClient(path="./chromadb")

# One collection for text embeddings and one for image embeddings. Both are
# fetched if they already exist, otherwise created, and both are configured
# with the "cosine" HNSW space. The generator is consumed left to right, so
# the text collection is requested first, then the image collection.
text_collection, image_collection = (
    client.get_or_create_collection(
        name=collection_name,
        metadata={"hnsw:space": "cosine"},
    )
    for collection_name in ("text_chroma_db", "image_chroma_db")
)

In [None]:
# Load the product records that the text-indexing cell iterates over.
# The path is assembled with os.path.join so it resolves on every OS: a raw
# Windows-style literal containing a backslash is treated as one single file
# name on POSIX systems, where open() would then fail.
text_data_path = os.path.join('data', 'product_injected_categories.json')
with open(text_data_path, 'r', encoding='utf-8') as json_file:
    text_data = json.load(json_file)

In [None]:
# Root folder walked by the image-indexing cell, and the file suffixes that
# count as images there (matched case-insensitively by that cell).
# The folder path is built with os.path.join rather than a backslash literal:
# on POSIX a backslash is not a separator, and os.walk silently yields nothing
# for a directory that does not exist, so no image would ever be indexed.
image_data = os.path.join('data', 'images-1-1100')  # Replace with your folder path
image_extensions = ('.jpg', '.jpeg', '.png', '.bmp', '.gif')  # Add more if needed

## Indexing

In [7]:
def image_to_base64(image_path):
    """Return the contents of the file at ``image_path`` as a Base64 string.

    Args:
        image_path: Path of the file to read; it is opened in binary mode.

    Returns:
        str: The Base64 encoding of the file's bytes, decoded as UTF-8 text.

    Raises:
        OSError: If the file cannot be opened or read.
    """
    with open(image_path, 'rb') as stream:
        raw_bytes = stream.read()
    encoded_bytes = base64.b64encode(raw_bytes)
    return encoded_bytes.decode('utf-8')

In [None]:
for item in text_data:
    text = f"Tên sách: {item['Name']}\n" + f"Nội dung sách: {item['Description']}"
    embedding = await text_embedding.generate_text_embedding(text)
    text_collection.add(
        embeddings=[embedding],
        documents=[text],
        metadatas=[{'id': str(item['Id']), 'name': item['Name'], 'description': item['Description']}],
        ids=[str(item['Id'])],
    )

In [None]:
for subdir, _, files in os.walk(image_data):
    for file in files:
        if file.lower().endswith(image_extensions):
            full_path = os.path.join(subdir, file)
            base64_image = image_to_base64(full_path)
            image_embedding = await image_embedding.generate_image_embedding(base64_image)
            image_collection.add(
                embeddings=[image_embedding],
                metadatas=[{'product_id': os.path.basename(subdir)[5:], 'image_id': os.path.basename(file)[:-4]}],
                ids=[os.path.basename(file)[:-4]],
            )