In [1]:
import cv2
from glob import glob
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from PIL import Image
import os
from io import BytesIO
import base64
import math

In [2]:
data_dir = "Stanford_Online_Products"
classes_dir = [d for d in glob(data_dir+"/*") if os.path.isdir(d)]

PER_CLASS_IMAGE_COUNT = 12
IMAGE_SIZE = (256, 256)

image_paths = [image for dir in classes_dir for image in glob(dir+"/*")[:PER_CLASS_IMAGE_COUNT]]
labels = [os.path.basename(os.path.dirname(image_path)) for image_path in image_paths]

print(len(image_paths))

144


In [3]:
payloads = pd.DataFrame({"image_path": image_paths, "label": labels})
payloads.head()

Unnamed: 0,image_path,label
0,Stanford_Online_Products\bicycle_final\1110851...,bicycle_final
1,Stanford_Online_Products\bicycle_final\1110851...,bicycle_final
2,Stanford_Online_Products\bicycle_final\1110851...,bicycle_final
3,Stanford_Online_Products\bicycle_final\1110851...,bicycle_final
4,Stanford_Online_Products\bicycle_final\1110851...,bicycle_final


In [4]:
def resize_image(image_path):
    image = cv2.imread(image_path)
    resized_image = cv2.resize(image, IMAGE_SIZE)

    return resized_image

def image_to_base64(image):
    ret, buffer = cv2.imencode('.jpeg', image)
    encoded_image = base64.b64encode(buffer.tobytes()).decode("utf-8")

    return encoded_image


resized_images = list(map(lambda x: resize_image(x), payloads["image_path"]))
base64_images = list(map(lambda x: image_to_base64(x), resized_images))

payloads["base64"] = base64_images

In [5]:
payloads

Unnamed: 0,image_path,label,base64
0,Stanford_Online_Products\bicycle_final\1110851...,bicycle_final,/9j/4AAQSkZJRgABAQAAAQABAAD/2wBDAAIBAQEBAQIBAQ...
1,Stanford_Online_Products\bicycle_final\1110851...,bicycle_final,/9j/4AAQSkZJRgABAQAAAQABAAD/2wBDAAIBAQEBAQIBAQ...
2,Stanford_Online_Products\bicycle_final\1110851...,bicycle_final,/9j/4AAQSkZJRgABAQAAAQABAAD/2wBDAAIBAQEBAQIBAQ...
3,Stanford_Online_Products\bicycle_final\1110851...,bicycle_final,/9j/4AAQSkZJRgABAQAAAQABAAD/2wBDAAIBAQEBAQIBAQ...
4,Stanford_Online_Products\bicycle_final\1110851...,bicycle_final,/9j/4AAQSkZJRgABAQAAAQABAAD/2wBDAAIBAQEBAQIBAQ...
...,...,...,...
139,Stanford_Online_Products\toaster_final\1114300...,toaster_final,/9j/4AAQSkZJRgABAQAAAQABAAD/2wBDAAIBAQEBAQIBAQ...
140,Stanford_Online_Products\toaster_final\1114471...,toaster_final,/9j/4AAQSkZJRgABAQAAAQABAAD/2wBDAAIBAQEBAQIBAQ...
141,Stanford_Online_Products\toaster_final\1114471...,toaster_final,/9j/4AAQSkZJRgABAQAAAQABAAD/2wBDAAIBAQEBAQIBAQ...
142,Stanford_Online_Products\toaster_final\1114471...,toaster_final,/9j/4AAQSkZJRgABAQAAAQABAAD/2wBDAAIBAQEBAQIBAQ...


In [6]:
from transformers import AutoImageProcessor, ResNetForImageClassification

processor = AutoImageProcessor.from_pretrained("microsoft/resnet-50")
model = ResNetForImageClassification.from_pretrained("microsoft/resnet-50")

inputs = processor(np.array(resized_images), return_tensors="pt")
outputs = model(**inputs)

embeddings = outputs.logits
print(embeddings)

Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.48, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.


tensor([[-11.6843, -10.2012,  -8.0000,  ...,  -9.0842,  -7.7204,  -6.0618],
        [-11.3121, -10.1082,  -7.5649,  ...,  -8.8721,  -7.6465,  -5.5173],
        [-11.3030,  -9.5538,  -8.3420,  ...,  -9.0864,  -7.4481,  -5.4479],
        ...,
        [-12.2224,  -7.6830, -11.2379,  ..., -11.0365, -11.8333, -10.3787],
        [-11.4611, -10.0367, -11.6826,  ..., -11.7087, -12.5188, -10.3918],
        [-11.0452, -10.6416, -11.1737,  ..., -12.1864, -10.7409, -10.4825]],
       grad_fn=<AddmmBackward0>)


In [7]:
embedding_len = len(embeddings[0])
print(embedding_len)

1000


In [8]:
from dotenv import load_dotenv
load_dotenv()

True

In [9]:
from qdrant_client import QdrantClient

q_client = QdrantClient(
    url=os.getenv("QDRANT_DB_URL"),
    api_key=os.getenv("QDRANT_API_KEY")
)
q_client

<qdrant_client.qdrant_client.QdrantClient at 0x1efd7283160>

In [10]:
from qdrant_client.models import VectorParams, Distance, Record

collection_name = "product_matching"
collection = q_client.recreate_collection(collection_name, vectors_config=VectorParams(size=embedding_len, distance=Distance.COSINE))
collection

  collection = q_client.recreate_collection(collection_name, vectors_config=VectorParams(size=embedding_len, distance=Distance.COSINE))


True

In [11]:
payload_dict = payloads.to_dict(orient="records")
payload_dict

[{'image_path': 'Stanford_Online_Products\\bicycle_final\\111085122871_0.JPG',
  'label': 'bicycle_final',
  'base64': '/9j/4AAQSkZJRgABAQAAAQABAAD/2wBDAAIBAQEBAQIBAQECAgICAgQDAgICAgUEBAMEBgUGBgYFBgYGBwkIBgcJBwYGCAsICQoKCgoKBggLDAsKDAkKCgr/2wBDAQICAgICAgUDAwUKBwYHCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgr/wAARCAEAAQADASIAAhEBAxEB/8QAHwAAAQUBAQEBAQEAAAAAAAAAAAECAwQFBgcICQoL/8QAtRAAAgEDAwIEAwUFBAQAAAF9AQIDAAQRBRIhMUEGE1FhByJxFDKBkaEII0KxwRVS0fAkM2JyggkKFhcYGRolJicoKSo0NTY3ODk6Q0RFRkdISUpTVFVWV1hZWmNkZWZnaGlqc3R1dnd4eXqDhIWGh4iJipKTlJWWl5iZmqKjpKWmp6ipqrKztLW2t7i5usLDxMXGx8jJytLT1NXW19jZ2uHi4+Tl5ufo6erx8vP09fb3+Pn6/8QAHwEAAwEBAQEBAQEBAQAAAAAAAAECAwQFBgcICQoL/8QAtREAAgECBAQDBAcFBAQAAQJ3AAECAxEEBSExBhJBUQdhcRMiMoEIFEKRobHBCSMzUvAVYnLRChYkNOEl8RcYGRomJygpKjU2Nzg5OkNERUZHSElKU1RVVldYWVpjZGVmZ2hpanN0dXZ3eHl6goOEhYaHiImKkpOUlZaXmJmaoqOkpaanqKmqsrO0tba3uLm6wsPExcbHyMnK0tPU1dbX2Nna4uPk5ebn6Onq8vP09fb3+Pn6/9oADAMBAAIRAxEAPwD4v/YP/wCCSX7LX7T/AOyV4S+OHxG1/wAbw6vriX73sWkataw2

In [12]:
records = []

for idx, _ in enumerate(payload_dict):
    record = Record(id=idx, payload=payload_dict[idx], vector=embeddings[idx].tolist())
    records.append(record)

In [13]:
q_client.upload_records(collection_name=collection_name, records=records)

  q_client.upload_records(collection_name=collection_name, records=records)


In [14]:
q_client.retrieve(
    collection_name=collection_name,
    ids=[0, 3, 100],
)

[Record(id=0, payload={'image_path': 'Stanford_Online_Products\\bicycle_final\\111085122871_0.JPG', 'label': 'bicycle_final', 'base64': '/9j/4AAQSkZJRgABAQAAAQABAAD/2wBDAAIBAQEBAQIBAQECAgICAgQDAgICAgUEBAMEBgUGBgYFBgYGBwkIBgcJBwYGCAsICQoKCgoKBggLDAsKDAkKCgr/2wBDAQICAgICAgUDAwUKBwYHCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgr/wAARCAEAAQADASIAAhEBAxEB/8QAHwAAAQUBAQEBAQEAAAAAAAAAAAECAwQFBgcICQoL/8QAtRAAAgEDAwIEAwUFBAQAAAF9AQIDAAQRBRIhMUEGE1FhByJxFDKBkaEII0KxwRVS0fAkM2JyggkKFhcYGRolJicoKSo0NTY3ODk6Q0RFRkdISUpTVFVWV1hZWmNkZWZnaGlqc3R1dnd4eXqDhIWGh4iJipKTlJWWl5iZmqKjpKWmp6ipqrKztLW2t7i5usLDxMXGx8jJytLT1NXW19jZ2uHi4+Tl5ufo6erx8vP09fb3+Pn6/8QAHwEAAwEBAQEBAQEBAQAAAAAAAAECAwQFBgcICQoL/8QAtREAAgECBAQDBAcFBAQAAQJ3AAECAxEEBSExBhJBUQdhcRMiMoEIFEKRobHBCSMzUvAVYnLRChYkNOEl8RcYGRomJygpKjU2Nzg5OkNERUZHSElKU1RVVldYWVpjZGVmZ2hpanN0dXZ3eHl6goOEhYaHiImKkpOUlZaXmJmaoqOkpaanqKmqsrO0tba3uLm6wsPExcbHyMnK0tPU1dbX2Nna4uPk5ebn6Onq8vP09fb3+Pn6/9oADAMBAAIRAxEAPwD4v/YP/wCCSX7LX7T/AOyV4S+OHxG1/wA

In [15]:
q_client.recommend(collection_name=collection_name,
                    positive=[3], 
                    limit=6)

  q_client.recommend(collection_name=collection_name,


[ScoredPoint(id=11, version=0, score=0.9905589, payload={'image_path': 'Stanford_Online_Products\\bicycle_final\\111265328556_3.JPG', 'label': 'bicycle_final', 'base64': '/9j/4AAQSkZJRgABAQAAAQABAAD/2wBDAAIBAQEBAQIBAQECAgICAgQDAgICAgUEBAMEBgUGBgYFBgYGBwkIBgcJBwYGCAsICQoKCgoKBggLDAsKDAkKCgr/2wBDAQICAgICAgUDAwUKBwYHCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgr/wAARCAEAAQADASIAAhEBAxEB/8QAHwAAAQUBAQEBAQEAAAAAAAAAAAECAwQFBgcICQoL/8QAtRAAAgEDAwIEAwUFBAQAAAF9AQIDAAQRBRIhMUEGE1FhByJxFDKBkaEII0KxwRVS0fAkM2JyggkKFhcYGRolJicoKSo0NTY3ODk6Q0RFRkdISUpTVFVWV1hZWmNkZWZnaGlqc3R1dnd4eXqDhIWGh4iJipKTlJWWl5iZmqKjpKWmp6ipqrKztLW2t7i5usLDxMXGx8jJytLT1NXW19jZ2uHi4+Tl5ufo6erx8vP09fb3+Pn6/8QAHwEAAwEBAQEBAQEBAQAAAAAAAAECAwQFBgcICQoL/8QAtREAAgECBAQDBAcFBAQAAQJ3AAECAxEEBSExBhJBUQdhcRMiMoEIFEKRobHBCSMzUvAVYnLRChYkNOEl8RcYGRomJygpKjU2Nzg5OkNERUZHSElKU1RVVldYWVpjZGVmZ2hpanN0dXZ3eHl6goOEhYaHiImKkpOUlZaXmJmaoqOkpaanqKmqsrO0tba3uLm6wsPExcbHyMnK0tPU1dbX2Nna4uPk5ebn6Onq8vP09fb3+Pn6/9oADAMBAAIRAxEAP