<a href="https://colab.research.google.com/github/v-chabaux/computer-vision/blob/main/dinov2_image_retrieval.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [35]:
# SRC : https://github.com/roboflow/notebooks/tree/main
!pip install faiss-gpu
!pip install supervision

Reading package lists... Done
Building dependency tree... Done
Reading state information... Done
The following additional packages will be installed:
  libneon27
The following NEW packages will be installed:
  davfs2 libneon27
0 upgraded, 2 newly installed, 0 to remove and 19 not upgraded.
Need to get 258 kB of archives.
After this operation, 627 kB of additional disk space will be used.
Get:1 http://archive.ubuntu.com/ubuntu jammy/universe amd64 libneon27 amd64 0.32.2-1 [102 kB]
Get:2 http://archive.ubuntu.com/ubuntu jammy/universe amd64 davfs2 amd64 1.6.1-1 [155 kB]
Fetched 258 kB in 0s (1,122 kB/s)
Preconfiguring packages ...
Selecting previously unselected package libneon27:amd64.
(Reading database ... 120874 files and directories currently installed.)
Preparing to unpack .../libneon27_0.32.2-1_amd64.deb ...
Unpacking libneon27:amd64 (0.32.2-1) ...
Selecting previously unselected package davfs2.
Preparing to unpack .../davfs2_1.6.1-1_amd64.deb ...
Unpacking davfs2 (1.6.1-1) ...
Set

In [44]:
import sys
import os
import faiss
import numpy as np
import torch
import torchvision.transforms as T
from PIL import Image
import cv2
import json
from tqdm.notebook import tqdm
from matplotlib import pyplot as plt
import supervision as sv
from webdav3.client import Client
from google.colab import drive
# Mount Google Drive at /content/drive (Colab only); the corpus folder used
# further down is read from /content/drive/MyDrive.
drive.mount('/content/drive')

def load_image(img: str) -> torch.Tensor:
    """Load an image file and return a tensor that can be fed to DINOv2.

    Args:
        img: Path of the image file to open.

    Returns:
        The output of the module-level ``transform_image`` pipeline for the
        image, with a leading batch dimension of size 1 added.

    Raises:
        OSError: If the file cannot be opened or decoded by PIL.
    """
    # The context manager releases the file handle as soon as the pixels have
    # been decoded. convert("RGB") yields three channels whatever the source
    # mode is (greyscale, palette, CMYK, RGBA), which a plain channel slice
    # cannot guarantee.
    with Image.open(img) as pil_img:
        rgb_img = pil_img.convert("RGB")
    # The [:3] slice is kept as a safety net in case transform_image is
    # replaced by a pipeline that emits extra channels.
    return transform_image(rgb_img)[:3].unsqueeze(0)

def create_index(files: list) -> tuple:
    """Embed every image in ``files`` and store the vectors in a FAISS L2 index.

    Each file is loaded with ``load_image``, embedded with the module-level
    ``dinov2_vits14`` model on ``device`` and appended to the index. Files that
    cannot be processed are reported and skipped instead of aborting the run.

    Side effects:
        Writes ``all_embeddings.json`` (path -> embedding) and ``data.bin``
        (the serialized FAISS index) to the current working directory.

    Args:
        files: Paths of the images to index.

    Returns:
        ``(index, all_embeddings)``. ``all_embeddings`` maps each successfully
        indexed path to its embedding as a nested list. Row ``i`` of the index
        corresponds to the ``i``-th key of ``all_embeddings`` (insertion
        order); this differs from ``files[i]`` as soon as a file was skipped.
    """
    # 384 must equal the length of the vectors returned by dinov2_vits14.
    index = faiss.IndexFlatL2(384)
    all_embeddings = {}
    skipped = []
    with torch.no_grad():
        for file in tqdm(files):
            try:
                embeddings = dinov2_vits14(load_image(file).to(device))
                vector = embeddings[0].cpu().numpy().reshape(1, -1)
                index.add(vector)
                # Recorded only after the vector is in the index, so the key
                # order of the dict always matches the FAISS row order.
                all_embeddings[file] = vector.tolist()
            except Exception as exc:
                # Best effort: keep indexing the remaining files, but say what
                # went wrong. `Exception` (not a bare except) lets
                # KeyboardInterrupt stop a long run.
                skipped.append(file)
                print(f"Skipping {file}: {type(exc).__name__}: {exc}")
    if skipped:
        print(f"{len(skipped)} of {len(files)} files could not be indexed.")
    with open("all_embeddings.json", "w") as f:
        json.dump(all_embeddings, f)
    faiss.write_index(index, "data.bin")
    return index, all_embeddings

def search_index(index: faiss.IndexFlatL2, embeddings: list, k: int = 3) -> list:
    """Look up the ``k`` entries of ``index`` closest to a query embedding.

    Only the first element of ``embeddings`` is used: it is reshaped into a
    single-row query and passed to ``index.search``.

    Args:
        index: FAISS index to query.
        embeddings: Sequence whose first element is the query vector (it must
            support ``.reshape``).
        k: Number of neighbours to request.

    Returns:
        The row of neighbour ids that ``index.search`` returns for the query.
    """
    query = np.array(embeddings[0].reshape(1, -1))
    _distances, neighbour_ids = index.search(query, k)
    return neighbour_ids[0]

Mounted at /content/drive


In [None]:
# Pick the compute device once and report which one is used.
device = torch.device('cuda' if torch.cuda.is_available() else "cpu")
if device.type == "cuda":
    print("Used GPU : {}\n".format(torch.cuda.get_device_name(0)))
else:
    print("No GPU available. CPU is used\n")

# Load the DINOv2 ViT-S/14 backbone from torch.hub and move it to the device.
dinov2_vits14 = torch.hub.load("facebookresearch/dinov2", "dinov2_vits14")
dinov2_vits14.to(device)
# The model is only used for inference: switch off any training-time behaviour.
dinov2_vits14.eval()

# Preprocessing applied by load_image() to every image before embedding.
transform_image = T.Compose([T.ToTensor(), T.Resize(244), T.CenterCrop(224), T.Normalize([0.5], [0.5])])

# Collect the corpus images. os.listdir() returns entries in arbitrary order,
# so the names are sorted to keep FAISS row ids reproducible between runs.
folder = "/content/drive/MyDrive/Corpus"
files = [os.path.join(folder, path) for path in sorted(os.listdir(folder)) if path.endswith(".jpg")]

# Embed the whole corpus; this also writes all_embeddings.json and data.bin.
data_index, all_embeddings = create_index(files)


In [None]:
search_file = "/content/drive/MyDrive/Corpusimage_887.jpg"

img = cv2.resize(cv2.imread(search_file), (416, 416))
print("Input image:")
%matplotlib inline
sv.plot_image(image=img, size=(6, 6))
print("*" * 20)
with torch.no_grad():
    embedding = dinov2_vits14(load_image(search_file).to(device))
    indices = search_index(data_index, np.array(embedding[0].cpu()).reshape(1, -1), k=10)
    for i, index in enumerate(indices):
        print()
        print(f"Image {i}: {files[index]}")
        img = cv2.resize(cv2.imread(files[index]), (416, 416))
        %matplotlib inline
        sv.plot_image(image=img, size=(6,6))