In [None]:
# Install vector database
# ! pip install vectordb

## Import modules

In [1]:
import os
import numpy as np
from tqdm.notebook import trange, tqdm
from PIL import Image, ImageFont, ImageDraw 
import torch
import clip
import json as js
from docarray import DocList, BaseDoc
from docarray.typing import NdArray
import numpy as np
from vectordb import InMemoryExactNNVectorDB, HNSWVectorDB
from IPython.display import clear_output
from natsort import natsorted
import pandas as pd
from typing import List
import webbrowser
import shutil
import random


## Constants

In [2]:
MODEL = "ViT-B/32"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print(DEVICE)

METADATA_PATH = "../data/metadata/"
KEYFRAME_PATH = "../data/keyframes/"
FEATURE_PATH = "../data/features/"
MAP_KEYFRAMES = "../data/map-keyframes/"
VIDEOS_PATH = "../data/videos/"
SCRIPT_PATH = "../data/scripts/"

WORKSPACE = "./vectordb"

cuda


In [None]:
# verify scripts
print(len(os.listdir(VIDEOS_PATH)))
print(len(os.listdir(SCRIPT_PATH)))
print(len(os.listdir(VIDEOS_PATH)) == len(os.listdir(SCRIPT_PATH)))

## Re-formating Dataset

#### Get video names

In [3]:
LEN_OF_KEYFRAME_NAME = 4

In [4]:
video_names = [name for name in os.listdir(KEYFRAME_PATH) if name != ".gitkeep"]
print(video_names)

['L01_V001', 'L01_V002', 'L01_V003', 'L01_V004', 'L01_V005', 'L01_V006', 'L01_V007', 'L01_V008', 'L01_V009', 'L01_V010', 'L01_V011', 'L01_V012', 'L01_V013', 'L01_V014', 'L01_V015', 'L01_V016', 'L01_V017', 'L01_V018', 'L01_V019', 'L01_V020', 'L01_V021', 'L01_V022', 'L01_V023', 'L01_V024', 'L01_V025', 'L01_V026', 'L01_V027', 'L01_V028', 'L01_V029', 'L01_V030', 'L01_V031', 'L02_V001', 'L02_V002', 'L02_V003', 'L02_V004', 'L02_V005', 'L02_V006', 'L02_V007', 'L02_V008', 'L02_V009', 'L02_V010', 'L02_V011', 'L02_V012', 'L02_V013', 'L02_V014', 'L02_V015', 'L02_V016', 'L02_V017', 'L02_V018', 'L02_V019', 'L02_V020', 'L02_V021', 'L02_V022', 'L02_V023', 'L02_V024', 'L02_V025', 'L02_V026', 'L02_V027', 'L02_V028', 'L02_V029', 'L02_V030', 'L03_V001', 'L03_V002', 'L03_V003', 'L03_V004', 'L03_V005', 'L03_V006', 'L03_V007', 'L03_V008', 'L03_V009', 'L03_V010', 'L03_V011', 'L03_V012', 'L03_V013', 'L03_V014', 'L03_V015', 'L03_V016', 'L03_V017', 'L03_V018', 'L03_V019', 'L03_V020', 'L03_V021', 'L03_V022', 'L0

In [5]:
for name in video_names:
    keyframes = [path for path in os.listdir(os.path.join(KEYFRAME_PATH, name))]
    for kf in keyframes:
        img_name = kf.split(".")[0]
        if len(img_name) != LEN_OF_KEYFRAME_NAME:
            changed_path = os.path.join(KEYFRAME_PATH, name, img_name.zfill(4) + ".jpg")
            old_path = os.path.join(KEYFRAME_PATH, name, kf)
            print(f"Change {old_path} to {changed_path}")
            os.rename(old_path, changed_path)
            

## Text embedding

In [6]:
class TextEmbedding():
  def __init__(self):
    self.device = DEVICE
    self.model, _ = clip.load(MODEL, device=self.device)

  def __call__(self, text: str) -> np.ndarray:
    text_inputs = clip.tokenize([text]).to(self.device)
    with torch.no_grad():
        text_feature = self.model.encode_text(text_inputs)[0]
    return text_feature.detach().cpu().numpy()
  
  def __call__(self, texts) -> np.ndarray:
    text_inputs = clip.tokenize(texts).to(self.device)
    with torch.no_grad():
        text_feature = self.model.encode_text(text_inputs)[0]
    return text_feature.detach().cpu().numpy()


In [7]:
text_embedding = TextEmbedding()

In [8]:
query = "A blue sky in the background"
query_feat = text_embedding(query)
print(len(query_feat))

512


In [9]:
querys = ["A blue sky in the background", "People hangout at the beach", "Birds are flying in the sky"]
querys_feat = text_embedding(querys)
print(len(querys_feat))

512


## Vector Database

### Frame Document Class

In [10]:
class FrameDoc(BaseDoc):
  embedding: NdArray[512]
  video_name = ""
  image_path = ""
  keyframe_id = 0
  actual_idx = 0
  metadata = {}
  
  def __str__(self):
    return f"""
            Video name: {self.video_name}
            Image path: {self.image_path}
            Keyframe Id: {self.keyframe_id}
            Actual keyframe idx: {self.actual_idx}
            Metadata: {self.metadata}
          """

##  Database Handler

In [None]:
class VectorDB: 
    text_embedding = TextEmbedding()
    workspace = os.getcwd()
    method = "ANN"
    def __init__(self, method="ANN"):
        # Check if parent workspace exists
        if not os.path.isdir(WORKSPACE):
            os.mkdir(WORKSPACE, 0o666)
        # Create new workspae
        exits = [int(name.rsplit("_")[1]) for name in os.listdir(WORKSPACE)]
        while True:
            id = random.getrandbits(128)
            if id not in exits:
                self.workspace = os.path.join(self.workspace, WORKSPACE, "DB_" + str(id))
                break
            
        self.method = method
        #   Approximate Nearest Neighbour based on HNSW algorithm
        if method == "ANN":
            self.DB = HNSWVectorDB[FrameDoc](workspace=self.workspace)
            
        # Exhaustive search on the embeddings
        else:
            self.DB = InMemoryExactNNVectorDB[FrameDoc](workspace=self.workspace)
        
        
    def index(self, doc_list: List[FrameDoc]):    
        # Index database
        self.DB.index(inputs=DocList[FrameDoc](doc_list))
        
    def search(self, query_text: str, topk=100):
        query_doc = FrameDoc(embedding=self.text_embedding(query_text))
        return self.DB.search(inputs=DocList[FrameDoc]([query_doc]), limit=topk)[0].matches
    
    def delete(self, del_doc_list: List[FrameDoc]):
        self.DB.delete(docs=DocList[FrameDoc](del_doc_list))
        

### Needed functions 

### Get all features files

In [None]:
def get_all_feats():
    return [os.path.join(FEATURE_PATH, file) for file in os.listdir(FEATURE_PATH) if file.endswith(".npy")]

In [None]:
all_feat_files = get_all_feats()
print(all_feat_files)
print(len(all_feat_files))

### Create all the Docs

In [None]:
def get_all_docs(npy_files):
    doc_list = []
    for feat_npy in npy_files:
        video_name = feat_npy[feat_npy.find("L"):].split('.')[0]
        feats_arr = np.load(os.path.join(feat_npy))
        # Load metadata
        metadata = {}
        with open(os.path.join(METADATA_PATH, video_name + ".json")) as meta_f:
            metadata = js.load(meta_f)
            map_kf = pd.read_csv(os.path.join(MAP_KEYFRAMES, video_name + ".csv"), usecols=["pts_time", "frame_idx"])
            metadata =  {key: metadata[key] for key in ["publish_date", "watch_url"]}
            for frame_idx, feat in enumerate(feats_arr):
                image_path = os.path.join(KEYFRAME_PATH, video_name, f"{frame_idx + 1:04d}.jpg")
                frame_metadata = metadata.copy()
                frame_metadata["watch_url"] = frame_metadata["watch_url"]  + "&t=" + str(map_kf["pts_time"][frame_idx]) + "s"
                actual_idx=map_kf["frame_idx"][frame_idx]
                doc_list.append(
                                FrameDoc(
                                        embedding=feat, 
                                        video_name=video_name, 
                                        image_path=image_path, 
                                        keyframe_id=frame_idx+1, 
                                        actual_idx=actual_idx, 
                                        metadata=frame_metadata
                                        )
                                )
                
    return doc_list

In [None]:
doc_list = get_all_docs(all_feat_files)

In [None]:
print(len(doc_list))

In [None]:
print(doc_list[100])

### Open video at specific times

In [None]:
def open_video(doc: FrameDoc):
    webbrowser.open(doc.metadata["watch_url"])

## Visualization functions

In [None]:
def get_images(results):
    images = []
    for i, res in enumerate(results):
        img = Image.open(res.image_path)
        draw = ImageDraw.Draw(img)
        font = ImageFont.truetype("arial.ttf", 50)
        draw.text(xy=(5, 5), text=f"{i}, {res.video_name}, {res.actual_idx}", align="left", fill=(255,0,0,255), font=font)
        images.append(img)
    return images

In [None]:
def visualize(imgs: List[Image.Image]) -> None:
    rows = len(imgs) // 2 # see more clearly
    if not rows:
        rows += 1
    cols = len(imgs) // rows
    if rows * cols < len(imgs):
        rows += 1
    w, h = imgs[0].size
    grid = Image.new('RGB', size=(cols * w, rows * h))
    grid_w, grid_h = grid.size

    for i, img in enumerate(imgs):
        
        grid.paste(img, box=(i % cols * w, i // cols * h))

    display(grid)

## DEMO

### Create DB

In [None]:
DB = VectorDB()
DB.index(doc_list)

### Query

#### Query text

Đoạn video về một chiếc xe 7 chỗ biển số màu vàng hiệu Toyota tông vào gốc cây bên lề đường. Bên cạnh chiếc xe có biển báo giao nhau với đường không ưu tiên. Trong đoạn video có cảnh một người bảo vệ áo màu xanh đang nằm trên chiếc xe máy.

### Filter 1

In [None]:
results1 = DB.search("Three men are running. On the left of the frame is a row of small green trees next to scattered items, and on the right is a carpet of grass.", 1000) # Nên lấy nhiều

In [None]:
clear_output()
images = get_images(results1[:50])

In [None]:
visualize(images)

In [None]:
open_video(results1[7])

### Filter 2

In [None]:
DB2 = VectorDB()
DB2.index(results1)

In [None]:
results2 = DB2.search("A bag is on the road", 500)

In [None]:
clear_output()
visualize(get_images(results2[:50]))

In [None]:
open_video(results2[2])

In [None]:
print(results2[39])

In [None]:
from utils import *
images = 