# How to Build a Text-Video Retrieval Engine

This notebook illustrates how to build a text-video retrieval engine from scratch using [Milvus](https://milvus.io/) and [Towhee](https://towhee.io/).


**What is Text-Video Retrieval?**

In simple words, text-video retrieval is: given a text query and a pool of candidate videos, select the video which corresponds to the text query.


**What are Milvus & Towhee?**

- Milvus is the most advanced open-source vector database built for AI applications and supports nearest neighbor embedding search across tens of millions of entries.
- Towhee is a framework that provides ETL for unstructured data using SoTA machine learning models.

We'll go through video retrieval procedures and evaluate the performance. Moreover, we managed to make the core functionality as simple as few lines of code, with which you can start hacking your own video retrieval engine.



## Preparation

### Install packages

Make sure you have installed required python packages:

| package |
| -- |
| towhee |
| towhee.models |
| pillow |
| ipython |
| gradio |

In [1]:
! python -m pip install -q towhee towhee.models pillow ipython gradio

### Prepare the data

First, we need to prepare the dataset and Milvus environment.   

[MSR-VTT (Microsoft Research Video to Text)](https://www.microsoft.com/en-us/research/publication/msr-vtt-a-large-video-description-dataset-for-bridging-video-and-language/) is a dataset for the open domain video captioning, which consists of 10,000 video clips.  

Download the MSR-VTT-1kA test set from [google drive](https://drive.google.com/file/d/1cuFpHiK3jV9cZDKcuGienxTg1YQeDs-w/view?usp=sharing) and unzip it, which contains just 1k videos.  
And the video captions text sentence information is in ./MSRVTT_JSFUSION_test.csv.

The data is organized as follows:
- **test_1k_compress:** 1k compressed test videos in MSR-VTT-1kA.
- **MSRVTT_JSFUSION_test.csv:** a csv file containing an ***key,vid_key,video_id,sentence***, for each video and caption text.

Let's take a quick look

In [2]:
! curl -L https://github.com/towhee-io/examples/releases/download/data/text_video_search.zip -O
! unzip -q -o text_video_search.zip

  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
  0     0    0     0    0     0      0      0 --:--:--  0:00:02 --:--:--     0
100  210M  100  210M    0     0  1961k      0  0:01:49  0:01:49 --:--:-- 2072k  1937k      0  0:01:51  0:01:29  0:00:22 1913k


In [3]:
import pandas as pd
import os

raw_video_path = './test_1k_compress' # 1k test video path.
test_csv_path = './MSRVTT_JSFUSION_test.csv' # 1k video caption csv.

test_sample_csv_path = './MSRVTT_JSFUSION_test_sample.csv'

sample_num = 1000 # you can change this sample_num to be smaller, so that this notebook will be faster.
test_df = pd.read_csv(test_csv_path)
print('length of all test set is {}'.format(len(test_df)))
sample_df = test_df.sample(sample_num, random_state=42)

sample_df['video_path'] = sample_df.apply(lambda x:os.path.join(raw_video_path, x['video_id']) + '.mp4', axis=1)

sample_df.to_csv(test_sample_csv_path)
print('random sample {} examples'.format(sample_num))

df = pd.read_csv(test_sample_csv_path)

df[['video_id', 'video_path', 'sentence']].head()

length of all test set is 1000
random sample 1000 examples


Unnamed: 0,video_id,video_path,sentence
0,video7579,./test_1k_compress/video7579.mp4,a girl wearing red top and black trouser is pu...
1,video7725,./test_1k_compress/video7725.mp4,young people sit around the edges of a room cl...
2,video9258,./test_1k_compress/video9258.mp4,a person is using a phone
3,video7365,./test_1k_compress/video7365.mp4,cartoon people are eating at a restaurant
4,video8068,./test_1k_compress/video8068.mp4,a woman on a couch talks to a a man


Define some helper function to convert video to gif so that we can have a look at these video-text pairs.   

In [4]:
from IPython import display
from pathlib import Path
from towhee import pipe, ops
from PIL import Image

def display_gif(video_path_list, text_list):
    html = ''
    for video_path, text in zip(video_path_list, text_list):
        html_line = '<img src="{}"> {} <br/>'.format(video_path, text)
        html += html_line
    return display.HTML(html)

    
def convert_video2gif(video_path, output_gif_path, num_samples=16):
    p = (
        pipe.input('video_path')
        .map('video_path', 'video_frames', ops.video_decode.ffmpeg(sample_type='uniform_temporal_subsample', args={'num_samples': num_samples}))
        .output('video_frames')
    )
    frames = p(video_path).to_list()[0][0]
    imgs = [Image.fromarray(frame) for frame in frames]
    imgs[0].save(fp=output_gif_path, format='GIF', append_images=imgs[1:], save_all=True, loop=0)


def display_gifs_from_video(video_path_list, text_list, tmpdirname = './tmp_gifs'):
    Path(tmpdirname).mkdir(exist_ok=True)
    gif_path_list = []
    for video_path in video_path_list:
        video_name = str(Path(video_path).name).split('.')[0]
        gif_path = Path(tmpdirname) / (video_name + '.gif')
        convert_video2gif(video_path, gif_path)
        gif_path_list.append(gif_path)
    return display_gif(gif_path_list, text_list)

Take a look at the ground-truth video-text pairs.

In [5]:
# sample_show_df = sample_df.sample(5, random_state=42)
sample_show_df = sample_df[:5]
video_path_list = sample_show_df['video_path'].to_list()
text_list = sample_show_df['sentence'].to_list()
tmpdirname = './tmp_gifs'
display_gifs_from_video(video_path_list, text_list, tmpdirname=tmpdirname)

### Create a Milvus Collection

Before getting started, please make sure that you have started a [Milvus service](https://milvus.io/docs/install_standalone-docker.md). This notebook uses [milvus 2.2.10](https://milvus.io/docs/v2.2.x/install_standalone-docker.md) and [pymilvus 2.2.11](https://milvus.io/docs/release_notes.md#2210).

In [None]:
! python -m pip install -q pymilvus==2.2.11

Let's first create a `video retrieval` collection that uses the [L2 distance metric](https://milvus.io/docs/metric.md#Euclidean-distance-L2) and an [IVF_FLAT index](https://milvus.io/docs/index.md#IVF_FLAT).

In [6]:
from pymilvus import connections, FieldSchema, CollectionSchema, DataType, Collection, utility

connections.connect(host='127.0.0.1', port='19530')

def create_milvus_collection(collection_name, dim):
    if utility.has_collection(collection_name):
        utility.drop_collection(collection_name)
    
    fields = [
    FieldSchema(name='id', dtype=DataType.INT64, descrition='ids', is_primary=True, auto_id=False),
    FieldSchema(name='embedding', dtype=DataType.FLOAT_VECTOR, descrition='embedding vectors', dim=dim)
    ]
    schema = CollectionSchema(fields=fields, description='video retrieval')
    collection = Collection(name=collection_name, schema=schema)

    # create IVF_FLAT index for collection.
    index_params = {
        'metric_type':'L2', #IP
        'index_type':"IVF_FLAT",
        'params':{"nlist":2048}
    }
    collection.create_index(field_name="embedding", index_params=index_params)
    return collection

In [132]:
collection = create_milvus_collection('text_video_retrieval', 512)

## Text-Video retrieval

In this section, we'll show how to build our text-video retrieval engine using Milvus. The basic idea behind text-video retrieval is the extract embeddings from videos using a Transformer network and store them in Milvus, then using another Transformer network to get text embeddings and compare with those stored in Milvus.

We use [Towhee](https://towhee.io/), a machine learning framework that allows for creating data processing pipelines. [Towhee](https://towhee.io/) also provides predefined operators which implement insert and query operation in Milvus.


### Load Video Embeddings into Milvus

We first extract embeddings from images with `CLIP4Clip` model and insert the embeddings into Milvus for indexing. Towhee provides a [method-chaining style API](https://towhee.readthedocs.io/en/main/index.html) so that users can assemble a data processing pipeline with operators.   

[CLIP4Clip](https://arxiv.org/abs/2104.08860) is a video-text retrieval model based on [CLIP (ViT-B)](https://github.com/openai/CLIP). The [towhee clip4clip operator](https://towhee.io/video-text-embedding/clip4clip) with pretrained weights can easily extract video embedding and text embedding by a few codes.

![image-4.png](attachment:image-4.png)

Before you start running the clip4clip operator, you should make sure you have make [git](https://git-scm.com/) and [git-lfs](https://git-lfs.github.com/) installed.    

For git(If you have installed, just skip it):
```
sudo apt-get install git
```
For git-lfs(You must install it for downloading checkpoint weights on backend):
```
sudo apt-get install git-lfs
git lfs install
```


In [112]:
%%time
import os
from towhee import ops, pipe, register
from towhee.operator import PyOperator


def read_csv(csv_file):
    import csv
    with open(csv_file, 'r', encoding='utf-8-sig') as f:
        data = csv.DictReader(f)
        for line in data:
            yield int(line['video_id'][len('video'):]), line['video_path']


dc = (
    pipe.input('csv_file')
    .flat_map('csv_file', ('video_id', 'video_path'), read_csv)
    .map('video_path', 'frames', ops.video_decode.ffmpeg(sample_type='uniform_temporal_subsample', args={'num_samples': 12}))
    .map('frames', 'vec', ops.video_text_embedding.clip4clip(model_name='clip_vit_b32', modality='video', device='cuda:1'))
    .map(('video_id', 'vec'), (), ops.ann_insert.milvus_client(host='127.0.0.1', port='19530', collection_name='text_video_retrieval'))
    .output('video_id')
)

CPU times: user 2.48 s, sys: 266 ms, total: 2.74 s
Wall time: 2.74 s


In [134]:
dc(test_sample_csv_path)
collection.load()

In [115]:
print('Total number of inserted data is {}.'.format(collection.num_entities))

Total number of inserted data is 0.


Here is detailed explanation for each line of the code:

- `read_csv(test_sample_csv_path)`: read tabular data from csv file;


- `ops.video_decode.ffmpeg`: subsample the video uniformly, and then get a list of images in the video, which are the input of the clip4clip model;

- `ops.video_text_embedding.clip4clip(model_name='clip_vit_b32', modality='video')`: extract embedding feature from the images subsampled from video, and then mean pool them in the temporal dimension, which repre.

- `ops.ann_insert.milvus_client(host='127.0.0.1', port='19530', collection_name='text_video_retrieval')`: insert video embedding features in to Milvus;


In [135]:
%%time

def read_csv(csv_file):
    import csv
    with open(csv_file, 'r', encoding='utf-8-sig') as f:
        data = csv.DictReader(f)
        for line in data:
            yield line['video_id'], line['sentence']

dc_search = (
    pipe.input('csv_file')
    .flat_map('csv_file', ('video_id', 'sentence'), read_csv)
    .map('sentence', 'vec', ops.video_text_embedding.clip4clip(model_name='clip_vit_b32', modality='text', device='cuda:1'))
    .map('vec', 'top10_raw_res', 
         ops.ann_search.milvus_client(
             host='127.0.0.1', port='19530', collection_name='text_video_retrieval', limit=10)
        )
    .map('top10_raw_res', ('top1', 'top5', 'top10'), lambda x: (x[:1], x[:5], x[:10]))
    .map('video_id', 'ground_truth', lambda x: x[len('video'):])
    .output('video_id', 'sentence', 'ground_truth', 'top1', 'top5', 'top10')
)

CPU times: user 2.53 s, sys: 235 ms, total: 2.77 s
Wall time: 2.76 s


In [137]:
from towhee import DataCollection

ret = DataCollection(dc_search(test_sample_csv_path))
ret.show()

video_id,sentence,ground_truth,top1,top5,top10
video7579,a girl wearing red top and black trouser is putting a sweater on a dog,7579,"[[7579, 1.4153318405151367]] len=1","[[7579, 1.4153318405151367],[9969, 1.4798685312271118],[8837, 1.4901210069656372],[9347, 1.4925624132156372],...] len=5","[[7579, 1.4153318405151367],[9969, 1.4798685312271118],[8837, 1.4901210069656372],[9347, 1.4925624132156372],...] len=10"
video7725,young people sit around the edges of a room clapping and raising their arms while others dance in the center during a party,7725,"[[7725, 1.360759973526001]] len=1","[[7725, 1.360759973526001],[8014, 1.4908323287963867],[8339, 1.491390585899353],[8442, 1.503340721130371],...] len=5","[[7725, 1.360759973526001],[8014, 1.4908323287963867],[8339, 1.491390585899353],[8442, 1.503340721130371],...] len=10"
video9258,a person is using a phone,9258,"[[9258, 1.4011759757995605]] len=1","[[9258, 1.4011759757995605],[9257, 1.421643614768982],[9697, 1.4404571056365967],[7910, 1.4957678318023682],...] len=5","[[9258, 1.4011759757995605],[9257, 1.421643614768982],[9697, 1.4404571056365967],[7910, 1.4957678318023682],...] len=10"
video7365,cartoon people are eating at a restaurant,7365,"[[7365, 1.4048030376434326]] len=1","[[7365, 1.4048030376434326],[8781, 1.460750699043274],[9537, 1.4721591472625732],[7831, 1.5040078163146973],...] len=5","[[7365, 1.4048030376434326],[8781, 1.460750699043274],[9537, 1.4721591472625732],[7831, 1.5040078163146973],...] len=10"
video8068,a woman on a couch talks to a a man,8068,"[[7162, 1.4739404916763306]] len=1","[[7162, 1.4739404916763306],[8304, 1.478574514389038],[8068, 1.4937106370925903],[7724, 1.4958460330963135],...] len=5","[[7162, 1.4739404916763306],[8304, 1.478574514389038],[8068, 1.4937106370925903],[7724, 1.4958460330963135],...] len=10"


## Evaluation

We have finished the core functionality of the text-video retrieval engine. However, we don't know whether it achieves a reasonable performance. We need to evaluate the retrieval engine against the ground truth so that we know if there is any room to improve it.

In this section, we'll evaluate the strength of our text-video retrieval using recall@topk:   
`Recall@topk` is the proportion of relevant items found in the top-k recommendations. Suppose that we computed recall at 10 and found it is 40% in our top-10 recommendation system. This means that 40% of the total number of the relevant items appear in the top-k results.

In [138]:
def mean_hit_ratio(actual, *predicteds):
    rets = []
    for predicted in predicteds:
        ratios = []
        for act, pre in zip(actual, predicted):
            hit_num = len(set(act) & set(pre))
            ratios.append(hit_num / len(act))
        rets.append(sum(ratios) / len(ratios))
    return rets

def get_label_from_raw_data(data):
    ret = []
    for item in data:
        ret.append(item[0])
    return ret


ev = (
    pipe.input('dc_data')
    .flat_map('dc_data', 'data', lambda x: x)
    .map('data', ('ground_truth', 'top1', 'top5', 'top10'), 
         lambda x: ([int(x.ground_truth)], 
                    get_label_from_raw_data(x.top1), 
                    get_label_from_raw_data(x.top5), 
                    get_label_from_raw_data(x.top10))
        )
    .window_all(('ground_truth', 'top1', 'top5', 'top10'), ('top1_mean_hit_ratio', 'top5_mean_hit_ratio', 'top10_mean_hit_ratio'), mean_hit_ratio)
    .output('top1_mean_hit_ratio', 'top5_mean_hit_ratio', 'top10_mean_hit_ratio')
)

DataCollection(ev(ret)).show()

top1_mean_hit_ratio,top5_mean_hit_ratio,top10_mean_hit_ratio
0.421,0.712,0.813


This result is almost identical to the recall metrics represented in the paper. You can find more detail about metrics in [paperwithcode](https://paperswithcode.com/paper/clip4clip-an-empirical-study-of-clip-for-end/review/?hl=30331).

## Release a Showcase

We've learnt how to build a reverse video search engine. Now it's time to add some interface and release a showcase. Towhee provides `towhee.api()` to wrap the data processing pipeline as a function with `.as_function()`. So we can build a quick demo with this `milvus_search_function` with [Gradio](https://gradio.app/).

In [146]:
import gradio

show_num = 3

milvus_search_pipe = (
    pipe.input('sentence')
    .map('sentence', 'vec', ops.video_text_embedding.clip4clip(model_name='clip_vit_b32', modality='text', device='cpu'))
    .map('vec', 'rows', 
         ops.ann_search.milvus_client(
             host='127.0.0.1', port='19530', collection_name='text_video_retrieval', limit=show_num)
    )
    .map('rows', 'videos_path',
         lambda rows: (os.path.join(raw_video_path, 'video' + str(r[0]) + '.mp4') for r in rows))
    .output('videos_path')
)


def milvus_search_function(text):
    return milvus_search_pipe(text).to_list()[0][0]

# print(milvus_search_function('a girl wearing red top and black trouser is putting a sweater on a dog'))


interface = gradio.Interface(milvus_search_function, 
                             inputs=[gradio.Textbox()],
                             outputs=[gradio.Video(format='mp4') for _ in range(show_num)]
                            )

interface.launch(inline=True, share=True)


Running on local URL:  http://127.0.0.1:7866
Running on public URL: https://e56ab44b-743a-4430.gradio.live

This share link expires in 72 hours. For free permanent hosting and GPU upgrades (NEW!), check out Spaces: https://huggingface.co/spaces




In [None]:
# import shutil
# shutil.rmtree(tmpdirname)