In [7]:
import chromadb
from logging import INFO, ERROR

VECTOR_DB_PATH = "./vectordatabase_beer_1"
COLLECTION_NAME = "beer_similarity_check_1"


def init_vectordb(path=VECTOR_DB_PATH):
    """Create (or open) the persistent Chroma store at ``path``.

    Best-effort: any failure is reported with a printed message and not raised.

    Args:
        path: directory of the persistent Chroma database.
    """
    try:
        chromadb.PersistentClient(path=path)
        print("Setup vectordb OK")
    except Exception as e:
        print(f"Setup vectordb failed! - {e}")


def get_chroma_collection(path=VECTOR_DB_PATH, name=COLLECTION_NAME):
    """Return the Chroma collection used for image similarity, creating it if missing.

    The collection is created with ``{"hnsw:space": "cosine"}`` metadata, i.e. a
    cosine-distance index.

    Args:
        path: directory of the persistent Chroma database.
        name: collection name.

    Returns:
        The collection object returned by ``get_or_create_collection``.

    Raises:
        Exception: whatever ``chromadb`` raised, re-raised after printing it.
            Returning ``None`` here would only defer the failure to a confusing
            ``AttributeError`` on the caller's first ``.add``/``.query``.
    """
    try:
        client = chromadb.PersistentClient(path=path)
        return client.get_or_create_collection(name=name, metadata={"hnsw:space": "cosine"})
    except Exception as e:
        print(f"Get vectordb client error {e}")
        raise

In [8]:
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from torchvision.models import vgg16
from PIL import Image
import os

TARGET_SIZE = (224, 224)


def process_image(image_path, normalize=True, target_size=TARGET_SIZE):
    """Load an image file and turn it into a batched tensor for the CNN.

    Args:
        image_path: path of the image file to read.
        normalize: if True, apply ``transforms.Normalize`` with the per-channel
            mean/std below (the commonly used ImageNet statistics).
        target_size: ``(height, width)`` passed to ``transforms.Resize``;
            defaults to ``TARGET_SIZE``.

    Returns:
        A tensor of shape ``(1, 3, height, width)``: the RGB image with a
        leading batch dimension of size 1.
    """
    # Context manager releases the file handle deterministically; Pillow only
    # auto-closes after load() for single-frame images. convert() returns a new,
    # fully loaded image, so it stays usable after the file is closed.
    with Image.open(image_path) as src:
        img = src.convert('RGB')

    transform_list = [
        transforms.Resize(target_size),
        transforms.ToTensor(),
    ]
    if normalize:
        transform_list.append(transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]))

    tensor = transforms.Compose(transform_list)(img)
    return tensor.unsqueeze(0)  # Add batch dimension

class EmbeddingModel:
    """Image embedder built from VGG16's convolutional feature extractor.

    Only ``vgg16(...).features`` is kept (the classifier head is dropped).
    ``embed`` returns that module's output flattened into a 1-D NumPy array.
    """

    # Per-channel normalization statistics (same values used by process_image).
    _MEAN = [0.485, 0.456, 0.406]
    _STD = [0.229, 0.224, 0.225]

    def __init__(self):
        try:
            # torchvision >= 0.13: explicit weights enum; `pretrained=True` is
            # deprecated there and maps to IMAGENET1K_V1.
            from torchvision.models import VGG16_Weights
            model = vgg16(weights=VGG16_Weights.IMAGENET1K_V1)
        except ImportError:
            # Older torchvision without the weights API.
            model = vgg16(pretrained=True)
        self.embedding_model = model.features  # Remove the classifier
        self.embedding_model.eval()

    def embed(self, img_path, normalize=True):
        """Embed an image with the feature extractor.

        Args:
            img_path: either a filesystem path (``str`` or ``os.PathLike``), which
                is loaded via ``process_image``, or an already-built tensor that is
                fed to the model as-is.
            normalize: apply mean/std normalization (for paths, inside
                ``process_image``; for tensors, via ``_normalize``).

        Returns:
            The model output flattened with ``.flatten()`` into a 1-D NumPy array.
            Note the whole tensor is flattened, so a batched tensor input yields
            one concatenated vector. Assumes the tensor lives on the CPU
            (``.numpy()`` requires it).

        Raises:
            AssertionError: if a path is given and does not exist.
        """
        if isinstance(img_path, (str, os.PathLike)):
            assert os.path.exists(img_path), f"Image path {img_path} not found!"
            img_tensor = process_image(img_path, normalize)
        else:
            img_tensor = img_path
            if normalize:
                img_tensor = self._normalize(img_tensor)

        with torch.no_grad():
            representation = self.embedding_model(img_tensor)

        return representation.flatten().numpy()

    def _normalize(self, img_tensor):
        """Normalize a tensor channel-wise with the class mean/std."""
        normalize = transforms.Normalize(mean=self._MEAN, std=self._STD)
        return normalize(img_tensor)

In [9]:
init_vectordb()
vectordb_coll = get_chroma_collection()
img_path = "output_knn/train"
embedding_model = EmbeddingModel()

# Expected layout: output_knn/train/<brand>/<image file>.
# Sorted traversal keeps the numeric ids stable across runs, and `upsert`
# (instead of `add`) makes re-running this cell against the persistent
# database overwrite existing ids rather than colliding with them.
count = 0
for brand in sorted(os.listdir(img_path)):
    brand_path = os.path.join(img_path, brand)
    if not os.path.isdir(brand_path):
        continue  # skip stray files next to the brand folders
    for f in sorted(os.listdir(brand_path)):
        f_path = os.path.join(brand_path, f)
        if not os.path.isfile(f_path):
            continue  # skip nested directories
        count += 1
        embed = embedding_model.embed(f_path, True)
        vectordb_coll.upsert(
            ids=[str(count)],
            embeddings=[embed.tolist()],
            metadatas=[{"url": f_path, "brand": brand}],
        )

Setup vectordb OK


In [12]:
# Query image from the `val` folder. The path is relative to the notebook's
# working directory, like `img_path` in the indexing cell.
test_image = os.path.join("output_knn", "val", "Huda", "000035quan_nhau_2.jpg")
N_RESULTS = 50  # number of nearest neighbours to retrieve

# Same embedding settings (normalize=True) as used when indexing.
embed_feature = embedding_model.embed(test_image).tolist()
results = vectordb_coll.query(
    query_embeddings=[embed_feature],  # one query: the flattened embedding as a list of floats
    n_results=N_RESULTS,
)


In [13]:
# Display the raw Chroma query response; each field (e.g. 'ids', 'distances')
# holds one list per query embedding, ordered from nearest to farthest.
results

{'ids': [['5846',
   '3323',
   '2927',
   '3085',
   '5607',
   '5778',
   '2708',
   '5684',
   '5744',
   '5567',
   '6148',
   '3027',
   '2720',
   '3005',
   '5422',
   '6069',
   '6869',
   '2880',
   '3767',
   '5400',
   '6115',
   '5981',
   '5879',
   '5904',
   '3039',
   '5757',
   '5837',
   '2932',
   '5638',
   '2969',
   '2712',
   '5974',
   '5949',
   '3986',
   '5536',
   '6062',
   '1593',
   '333',
   '470',
   '3096',
   '5826',
   '2923',
   '3060',
   '6925',
   '3073',
   '3108',
   '5673',
   '2750',
   '2899',
   '5258']],
 'distances': [[0.7422964572906494,
   0.7473919987678528,
   0.7594253420829773,
   0.7643394470214844,
   0.7692051529884338,
   0.7712525129318237,
   0.7784634828567505,
   0.7828623056411743,
   0.7845882773399353,
   0.788404107093811,
   0.7898454666137695,
   0.7908550500869751,
   0.7909548878669739,
   0.7923252582550049,
   0.7949419617652893,
   0.7960531711578369,
   0.7961166501045227,
   0.7967243790626526,
   0.799093484878