In [None]:
# Install vector database
# ! pip install vectordb

## Import modules

In [None]:
import os
import numpy as np
from tqdm.notebook import trange, tqdm
from PIL import Image, ImageFont, ImageDraw 
import torch
import clip
import json as js
from docarray import DocList, BaseDoc
from docarray.typing import NdArray
import numpy as np
from vectordb import InMemoryExactNNVectorDB, HNSWVectorDB
from IPython.display import clear_output
from natsort import natsorted
import pandas as pd
from typing import List

## Constants

In [None]:
# CLIP backbone name handed to clip.load().
MODEL = "ViT-B/32"

# Prefer the GPU when one is visible to torch.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
print(DEVICE)

# Dataset locations (relative to this notebook).
METADATA_PATH = "../data/metadata/"
KEYFRAME_PATH = "../data/keyframes/"
FEATURE_PATH = "../data/features/"
MAP_KEYFRAMES = "../data/map-keyframes/"

# Workspace for index database
WORKSPACE_PATH = "./workspace_aic23"

## Reformatting Dataset

#### Get video names

In [None]:
# Target width of a keyframe file stem after zero-padding (e.g. "0001").
LEN_OF_KEYFRAME_NAME: int = len("0000")

In [None]:
# Each sub-directory of KEYFRAME_PATH holds the keyframes of one video.
# Keep directories only (skips .gitkeep and any other stray file, which would
# make os.listdir() in the rename loop fail) and hidden folders such as
# .ipynb_checkpoints. Sorted because os.listdir() order is arbitrary.
video_names = sorted(
    name
    for name in os.listdir(KEYFRAME_PATH)
    if not name.startswith(".") and os.path.isdir(os.path.join(KEYFRAME_PATH, name))
)
print(video_names)

In [None]:
# Zero-pad keyframe file names (e.g. "1.jpg" -> "0001.jpg") so they sort in
# numeric order and match the f"{idx:04d}.jpg" paths built in get_all_docs.
for name in video_names:
    video_dir = os.path.join(KEYFRAME_PATH, name)
    if not os.path.isdir(video_dir):
        continue
    for kf in os.listdir(video_dir):
        # Hidden files (e.g. .gitkeep) previously became "" -> "0000.jpg".
        if kf.startswith("."):
            continue
        img_name = os.path.splitext(kf)[0]
        if len(img_name) >= LEN_OF_KEYFRAME_NAME:
            continue  # already padded; zfill() would not change it
        changed_path = os.path.join(video_dir, img_name.zfill(LEN_OF_KEYFRAME_NAME) + ".jpg")
        old_path = os.path.join(video_dir, kf)
        if os.path.exists(changed_path):
            # os.rename() silently replaces an existing target on POSIX.
            print(f"Skip {old_path}: {changed_path} already exists")
            continue
        print(f"Change {old_path} to {changed_path}")
        os.rename(old_path, changed_path)

## Text embedding

In [None]:
class TextEmbedding():
    """Encode text prompts with the CLIP text encoder.

    The CLIP model named by MODEL is loaded onto DEVICE once, in __init__.
    """

    def __init__(self):
        self.device = DEVICE
        self.model, _ = clip.load(MODEL, device=self.device)

    def __call__(self, texts: "str | list[str]") -> np.ndarray:
        """Embed a single prompt or a batch of prompts.

        Args:
            texts: one prompt (``str``) or a list of prompts.

        Returns:
            For a ``str``: a 1-D array holding that prompt's embedding.
            For a list: a 2-D array with one row per prompt, in input order.

        Raises:
            ValueError: if an empty list is given.
        """
        # A single __call__ handles both cases: previously two definitions
        # existed, the second shadowed the first and always returned row 0,
        # so a batch of N prompts collapsed into one vector.
        is_single = isinstance(texts, str)
        batch = [texts] if is_single else list(texts)
        if not batch:
            raise ValueError("texts must contain at least one prompt")
        text_inputs = clip.tokenize(batch).to(self.device)
        with torch.no_grad():
            text_features = self.model.encode_text(text_inputs)
        features = text_features.cpu().numpy()
        return features[0] if is_single else features


In [None]:
# Text encoder for the ad-hoc query cells below; constructing it loads the CLIP model.
text_embedding = TextEmbedding()

In [None]:
# Smoke test: embed one prompt and report the length of its embedding vector.
query = "A blue sky in the background"
query_feat = text_embedding(query)
print(query_feat.shape[0])

In [None]:
# Batch of prompts. A batched __call__ should return one embedding row per
# prompt, so this is expected to print 3; if it prints the embedding width
# instead, __call__ returned only a single row.
querys = ["A blue sky in the background", "People hangout at the beach", "Birds are flying in the sky"]
querys_feat = text_embedding(querys)
print(len(querys_feat))

## Vector Database

### Frame Document Class

In [None]:
class FrameDoc(BaseDoc):
  """One keyframe record stored in the vector database.

  Fields:
    embedding: 512-value feature vector of the keyframe (presumably a CLIP
      image feature, matching the text embeddings it is compared with).
    video_name: source video name; also the keyframe sub-folder name.
    image_path: path of the keyframe image on disk.
    keyframe_id: 1-based keyframe number within the video.
    actual_idx: frame index in the original video (from the map-keyframes CSV).
    metadata: per-video metadata dict loaded from METADATA_PATH.
  """
  # Explicit annotations: BaseDoc is a pydantic model, and pydantic v2 rejects
  # non-annotated class attributes. Pydantic v1 inferred exactly these types
  # from the defaults, so behaviour under v1 is unchanged.
  embedding: NdArray[512]
  video_name: str = ""
  image_path: str = ""
  keyframe_id: int = 0
  actual_idx: int = 0
  metadata: dict = {}

  def __str__(self):
    """Multi-line human-readable summary used when printing a doc."""
    return f"""
          Video name: {self.video_name}
          Image path: {self.image_path}
          Keyframe Id: {self.keyframe_id}
          Actual keyframe idx: {self.actual_idx}
          Metadata: {self.metadata}
          """

## Database Handler

In [None]:
class VectorDB:
    """Text-queryable index of FrameDoc records backed by the `vectordb` package."""

    # CLIP text encoder shared by every instance that is not given its own.
    # Created on first construction instead of when this cell runs, so merely
    # defining the class no longer loads an extra CLIP model.
    text_embedding = None
    # NOTE(review): unused in this notebook; as a class attribute it is shared
    # by all instances.
    backups = []

    def __init__(self, workspace, type="ANN", text_embedding=None):
        """Create or open the index stored in ``workspace``.

        Args:
            workspace: directory passed to the vectordb backend.
            type: "ANN" selects HNSWVectorDB (approximate nearest neighbour,
                HNSW); any other value selects InMemoryExactNNVectorDB
                (exhaustive search over all embeddings).
            text_embedding: optional callable mapping a query string to an
                embedding (e.g. an existing TextEmbedding), so an already
                loaded model can be reused instead of loading another one.
        """
        if text_embedding is not None:
            self.text_embedding = text_embedding
        elif VectorDB.text_embedding is None:
            # The parameter `type` shadows the builtin, so name the class explicitly.
            VectorDB.text_embedding = TextEmbedding()

        if type == "ANN":
            self.DB = HNSWVectorDB[FrameDoc](workspace=workspace)
        else:
            self.DB = InMemoryExactNNVectorDB[FrameDoc](workspace=workspace)

    def index(self, doc_list: List[FrameDoc]):
        """Add ``doc_list`` to the index; an empty list is a no-op."""
        # e.g. a refinement stage whose previous search returned no matches.
        if len(doc_list) == 0:
            return
        self.DB.index(inputs=DocList[FrameDoc](doc_list))

    def search(self, query_text: str, topk=100):
        """Embed ``query_text`` and return its result (``.matches`` limited to ``topk``)."""
        query_doc = FrameDoc(embedding=self.text_embedding(query_text))
        return self.DB.search(inputs=DocList[FrameDoc]([query_doc]), limit=topk)[0]

    def delete(self, del_doc_list: List[FrameDoc]):
        """Remove the given documents from the index."""
        self.DB.delete(docs=DocList[FrameDoc](del_doc_list))

### Helper functions

### Get all features files

In [None]:
def get_all_feats(feature_path=None):
    """Return paths of all ``.npy`` feature files in a directory, sorted by file name.

    Args:
        feature_path: directory to scan; defaults to FEATURE_PATH.

    Returns:
        List of paths joined with ``feature_path``. Sorted because os.listdir()
        returns entries in arbitrary order, which would make the document order
        (and thus tie-breaking in search results) vary between machines.
    """
    if feature_path is None:
        feature_path = FEATURE_PATH
    return [
        os.path.join(feature_path, file)
        for file in sorted(os.listdir(feature_path))
        if file.endswith(".npy")
    ]

In [None]:
# Feature files found under FEATURE_PATH, and how many there are.
all_feat_files = get_all_feats()
print(all_feat_files)
print(len(all_feat_files))

### Create all the Docs

In [None]:
def get_all_docs(npy_files):
    """Build one FrameDoc per keyframe feature vector.

    For each feature file (named after its video) this reads:
      * the feature matrix (row i = keyframe i+1),
      * METADATA_PATH/<video_name>.json, attached to every doc of that video,
      * MAP_KEYFRAMES/<video_name>.csv, column ``frame_idx`` (row i gives the
        actual frame index of keyframe i+1).

    Args:
        npy_files: iterable of paths to ``.npy`` feature files.

    Returns:
        List of FrameDoc, grouped by file in the given order, then by keyframe.

    Raises:
        ValueError: if a video's keyframe map has fewer rows than its features.
    """
    doc_list = []
    for feat_npy in tqdm(npy_files, desc="Building docs"):
        # Look for the "L.." video prefix in the file name only; searching the
        # whole path broke whenever a parent directory contained an "L".
        stem = os.path.splitext(os.path.basename(feat_npy))[0]
        start = stem.find("L")
        video_name = stem[start:] if start != -1 else stem
        feats_arr = np.load(feat_npy)

        # Metadata may contain non-ASCII text; do not depend on the platform encoding.
        with open(os.path.join(METADATA_PATH, video_name + ".json"), encoding="utf-8") as meta_f:
            metadata = js.load(meta_f)
        frame_idx_col = pd.read_csv(
            os.path.join(MAP_KEYFRAMES, video_name + ".csv"), usecols=["frame_idx"]
        )["frame_idx"]
        if len(frame_idx_col) < len(feats_arr):
            raise ValueError(
                f"{video_name}: {len(feats_arr)} feature rows but only "
                f"{len(frame_idx_col)} rows in the keyframe map"
            )

        for kf_pos, feat in enumerate(feats_arr):
            keyframe_id = kf_pos + 1  # keyframe files are 1-based: 0001.jpg, 0002.jpg, ...
            image_path = os.path.join(KEYFRAME_PATH, video_name, f"{keyframe_id:04d}.jpg")
            doc_list.append(FrameDoc(embedding=feat, video_name=video_name, image_path=image_path,
                                     keyframe_id=keyframe_id,
                                     actual_idx=int(frame_idx_col.iloc[kf_pos]),
                                     metadata=metadata))

    return doc_list

In [None]:
# One FrameDoc per keyframe across every feature file.
doc_list = get_all_docs(all_feat_files)

In [None]:
# Total number of keyframe documents to index.
print(len(doc_list))

In [None]:
# Inspect one document (raises IndexError if there are fewer than 101 docs).
print(doc_list[100])

## Visualization functions

In [None]:
def get_images(result_matches, drawed = None):
    """Load the keyframe image of each match, optionally captioned.

    Args:
        result_matches: iterable of FrameDoc-like objects exposing
            ``image_path``, ``video_name`` and ``actual_idx``.
        drawed: if truthy, draw "<video_name>, <actual_idx>" in red at the
            top-left corner of each image.

    Returns:
        List of RGB PIL images, in the same order as ``result_matches``.
    """
    font = None
    if drawed:
        # Load the font once, not once per image. "arial.ttf" may not be
        # installed (common on Linux), in which case truetype() raises OSError.
        try:
            font = ImageFont.truetype("arial.ttf", 50)
        except OSError:
            font = ImageFont.load_default()

    images = []
    for res in result_matches:
        # convert() loads the pixels so the file can be closed right away, and
        # yields an RGB image that accepts the colour tuple used below.
        with Image.open(res.image_path) as src:
            img = src.convert("RGB")
        if drawed:
            draw = ImageDraw.Draw(img)
            draw.text(xy=(5, 5), text=f"{res.video_name}, {res.actual_idx}", align="left", fill=(255,0,0,255), font=font)
        images.append(img)

    return images

In [None]:
def visualize(imgs: List[Image.Image], cols=None) -> None:
    """Display images tiled in a grid whose cells have the first image's size.

    Args:
        imgs: images to show, laid out row by row.
        cols: number of columns. Defaults to the original layout: a single
            row for 1-3 images, two columns for 4 or more.

    Images whose size differs from the first one are resized to fit their
    cell (previously they were cropped by, or left gaps next to, neighbours).
    An empty list prints a message instead of raising IndexError.
    """
    from IPython.display import display

    if not imgs:
        print("No images to visualize.")
        return

    n = len(imgs)
    if cols is None:
        rows = max(n // 2, 1)  # see more clearly
        cols = n // rows
    cols = max(1, min(cols, n))
    rows = -(-n // cols)  # ceiling division: enough rows for every image

    w, h = imgs[0].size
    grid = Image.new('RGB', size=(cols * w, rows * h))
    for i, img in enumerate(imgs):
        if img.size != (w, h):
            img = img.resize((w, h))
        grid.paste(img, box=(i % cols * w, i // cols * h))

    display(grid)

## DEMO

### Create DB

In [None]:
# Main index over every keyframe. Kept under WORKSPACE_PATH (the constant
# defined for this purpose) instead of a bare "DB1" in the working directory.
DB1_WORKSPACE = os.path.join(WORKSPACE_PATH, "DB1")
os.makedirs(DB1_WORKSPACE, exist_ok=True)
DB = VectorDB(DB1_WORKSPACE)

In [None]:
# Index only once.
# NOTE(review): calling index() again presumably adds the same frames a second
# time (doc_list holds the same frames each run) — confirm before re-running.
DB.index(doc_list) # ~ 3 mins 

### Query

#### Query text

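Đoạn video về một người phụ nữ mặc áo màu vàng đang bỏ rác vào thùng rác. Thùng rác màu xanh lá đậm và nắp thùng màu đỏ. Rác đang bỏ vào thùng cho biết đó là 1kg baby spinach.

(English: A video of a woman in a yellow shirt putting trash into a trash bin. The bin is dark green with a red lid. The trash being thrown away shows it is a 1 kg bag of baby spinach.)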

In [None]:
# Stage 1: broad query over the full index. Use a large topk so the refinement
# stages below (which only index these matches) have plenty of candidates.
results = DB.search("A woman in yellow shirt is putting trash into the bin", 5000)

In [None]:
# NOTE(review): clear_output() clears this cell's own output area, which is
# still empty here — looks like a leftover from a merged search+display cell.
clear_output()
# Show the top 50 stage-1 matches. Likely answer: L02_V016, frame 14190.
visualize(get_images(results.matches[:50], drawed=True))

In [None]:
# Refinement stage 2: re-rank only the stage-1 matches with a narrower prompt.
# vectordb persists its data under `workspace` (per its docs), so the directory
# is emptied first; otherwise re-running after a different stage-1 query would
# mix stale documents into this stage.
import shutil

DB2_WORKSPACE = os.path.join(WORKSPACE_PATH, "DB2")
shutil.rmtree(DB2_WORKSPACE, ignore_errors=True)
os.makedirs(DB2_WORKSPACE, exist_ok=True)
DB2 = VectorDB(workspace=DB2_WORKSPACE)
DB2.index(results.matches)

In [None]:
# Stage 2 query: focus on the bin's colours among the stage-1 candidates.
results2 = DB2.search("green trash can with red lid", 1000)

In [None]:
# NOTE(review): clear_output() only clears this cell's (still empty) output.
clear_output()
# Show the top 50 stage-2 matches.
visualize(get_images(results2.matches[:50], drawed=True))

In [None]:
# Refinement stage 3: index only the stage-2 matches.
# vectordb persists its data under `workspace` (per its docs); start from an
# empty directory so stale documents from earlier runs are not searched.
import shutil

DB3_WORKSPACE = os.path.join(WORKSPACE_PATH, "DB3")
shutil.rmtree(DB3_WORKSPACE, ignore_errors=True)
os.makedirs(DB3_WORKSPACE, exist_ok=True)
DB3 = VectorDB(DB3_WORKSPACE)
DB3.index(results2.matches)

In [None]:
# Stage 3 query: combine person, action and bin description.
results3 = DB3.search("A bag is being put in the green trash bin with a red lid by a woman in yellow shirt", 500) 

In [None]:
# Show the top 50 stage-3 matches.
visualize(get_images(result_matches=results3.matches[:50], drawed=True))

In [None]:
# Refinement stage 4: index only the stage-3 matches.
# VectorDB() with no arguments raised TypeError because `workspace` is a
# required parameter, so give this stage its own directory under WORKSPACE_PATH.
# It is emptied first so documents persisted by an earlier run are not searched.
import shutil

DB4_WORKSPACE = os.path.join(WORKSPACE_PATH, "DB4")
shutil.rmtree(DB4_WORKSPACE, ignore_errors=True)
os.makedirs(DB4_WORKSPACE, exist_ok=True)
DB4 = VectorDB(workspace=DB4_WORKSPACE)
DB4.index(results3.matches)

In [None]:
# Stage 4 query: focus on the trash bag being held.
results4 = DB4.search("A woman in yellow shirt is holding a bag of trash", 100) 


In [None]:
# Show the top 50 stage-4 (final) matches.
visualize(get_images(result_matches=results4.matches[:50], drawed=True))