<a href="https://colab.research.google.com/github/tonykipkemboi/milvus-youtube-notebook/blob/main/milvus.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Transcript-based Video Retrieval System using Milvus, YouTube API, and Transformers

## Introduction

The idea is to take the text of YouTube video transcripts, convert it into vector embeddings, and store those embeddings in Milvus so that videos can be retrieved with an efficient similarity search.

In [3]:
! pip install -q --upgrade google-api-python-client google-auth-httplib2 google-auth-oauthlib youtube_transcript_api pymilvus milvus transformers

In [4]:
from googleapiclient.discovery import build
from youtube_transcript_api import YouTubeTranscriptApi
from functools import lru_cache

## Gathering Data: YouTube API

The first step is to gather the transcripts of the videos from a specific YouTube playlist. For this project, I chose the "*Vector Database Fundamentals*" playlist from the Zilliz channel because it has shorter transcripts.

In [74]:
youtube = build('youtube', 'v3', developerKey='YOUR-KEY-HERE')

In [6]:
@lru_cache(maxsize=50)
def get_channel_id(channel_title):
    request = youtube.search().list(
        q=channel_title,
        type='channel',
        part='id',
        maxResults=1
    )
    response = request.execute()
    if response['items']:
        return response['items'][0]['id']['channelId']
    return None

In [7]:
@lru_cache(maxsize=50)
def fetch_playlists(channel_title, topic):
    channel_id = get_channel_id(channel_title)
    if not channel_id:
        print(f"No channel found for title: {channel_title}")
        return None

    request = youtube.search().list(
        q=topic,
        type='playlist',
        channelId=channel_id,
        part='snippet',
        maxResults=5
    )
    response = request.execute()
    playlist = next((item for item in response['items'] if item['snippet']['title'] == topic), None)
    return playlist

In [8]:
# Search specific channel title and topic
playlist = fetch_playlists("Zilliz", "Vector Database Fundamentals")
playlist

{'kind': 'youtube#searchResult',
 'etag': 'lWJxKqvU9lcbxknnk96EQUfY9nw',
 'id': {'kind': 'youtube#playlist',
  'playlistId': 'PLPg7_faNDlT6wXMi2vfG0zJ6pK-gq6KE8'},
 'snippet': {'publishedAt': '2023-07-21T03:04:19Z',
  'channelId': 'UCMCo_F7pKjMHBlfyxwOPw-g',
  'title': 'Vector Database Fundamentals',
  'description': 'Learn the basics of vector databases, indexes and developer tooling.',
  'thumbnails': {'default': {'url': 'https://i.ytimg.com/vi/fhzDrXCpIRQ/default.jpg',
    'width': 120,
    'height': 90},
   'medium': {'url': 'https://i.ytimg.com/vi/fhzDrXCpIRQ/mqdefault.jpg',
    'width': 320,
    'height': 180},
   'high': {'url': 'https://i.ytimg.com/vi/fhzDrXCpIRQ/hqdefault.jpg',
    'width': 480,
    'height': 360}},
  'channelTitle': 'Zilliz',
  'liveBroadcastContent': 'none',
  'publishTime': '2023-07-21T03:04:19Z'}}

In [9]:
@lru_cache(maxsize=50)
def fetch_video_ids(playlist):
    request = youtube.playlistItems().list(
        playlistId=playlist,
        maxResults=15,
        part="snippet"
    )
    response = request.execute()
    video_ids = [item['snippet']['resourceId']['videoId'] for item in response['items']]
    return video_ids
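`fetch_video_ids` makes a single request with `maxResults=15`, so a playlist with more than 15 videos would be cut off. The YouTube Data API returns a `nextPageToken` when more results exist, and `playlistItems().list` accepts a `pageToken` argument (up to 50 items per page). Below is a minimal sketch of the pagination loop, assuming a hypothetical `fetch_page(page_token)` callable that wraps the real request; a stubbed two-page response with made-up IDs stands in for the API so the logic runs offline.

```python
from typing import Callable, Optional


def collect_video_ids(fetch_page: Callable[[Optional[str]], dict]) -> list:
    """Follow nextPageToken until the API stops returning one.

    fetch_page is a hypothetical wrapper; with the real client it could be:
        lambda token: youtube.playlistItems().list(
            playlistId=playlist_id, part="snippet",
            maxResults=50, pageToken=token).execute()
    """
    video_ids = []
    page_token = None
    while True:
        response = fetch_page(page_token)
        video_ids.extend(
            item["snippet"]["resourceId"]["videoId"] for item in response["items"]
        )
        page_token = response.get("nextPageToken")
        if not page_token:  # last page reached
            return video_ids


# Stubbed two-page response standing in for the YouTube API (made-up IDs)
def _item(video_id):
    return {"snippet": {"resourceId": {"videoId": video_id}}}


FAKE_PAGES = {
    None: {"items": [_item("a"), _item("b")], "nextPageToken": "page2"},
    "page2": {"items": [_item("c")]},
}

print(collect_video_ids(FAKE_PAGES.get))  # ['a', 'b', 'c']
```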

In [10]:
all_video_ids = fetch_video_ids(playlist['id']['playlistId'])
all_video_ids

['fhzDrXCpIRQ',
 'v1AiTcwZaRY',
 'VOerTAir9SU',
 '5Fhz6W6-E_8',
 'm3WwttSBDvo',
 'o-eE-HZzUQY',
 'Yg64cgeNhkE',
 '7M798u_J3rg']

## Fetching Transcripts: YouTube Transcript API

With the video IDs in hand, the next step was to fetch the transcripts. One snag I hit while testing other playlists is that some videos don't have transcripts. I haven't explored that rabbit hole enough yet.

In [11]:
@lru_cache(maxsize=50)
def get_transcripts(video_ids: tuple):
    transcripts = {}
    for video_id in video_ids:
        try:
            transcript = YouTubeTranscriptApi.get_transcript(video_id)
            transcripts[video_id] = transcript
        except Exception as exc:  # e.g. transcripts disabled or unavailable for this video
            print(f"Skipping {video_id}: {type(exc).__name__}")  # TO-DO: Handle videos without transcripts
    return transcripts
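One way to act on the TO-DO is to return the IDs of skipped videos alongside the transcripts, so they can be reported or retried later. A minimal sketch, assuming a hypothetical `fetch` callable (in the notebook it would be `YouTubeTranscriptApi.get_transcript`); the stub below raises for one made-up ID to simulate a video without captions.

```python
from typing import Callable, Dict, List, Tuple


def get_transcripts_with_report(
    video_ids: Tuple[str, ...], fetch: Callable[[str], list]
) -> Tuple[Dict[str, list], Dict[str, str]]:
    """Return (transcripts, skipped), where skipped maps video_id -> error message."""
    transcripts: Dict[str, list] = {}
    skipped: Dict[str, str] = {}
    for video_id in video_ids:
        try:
            transcripts[video_id] = fetch(video_id)
        except Exception as exc:  # any failure: record it instead of dropping it silently
            skipped[video_id] = f"{type(exc).__name__}: {exc}"
    return transcripts, skipped


# Hypothetical stand-in for YouTubeTranscriptApi.get_transcript
def fake_fetch(video_id: str) -> List[dict]:
    if video_id == "no-captions":
        raise LookupError("no transcript available")
    return [{"text": f"hello from {video_id}", "start": 0.0, "duration": 1.0}]


found, missing = get_transcripts_with_report(("vid1", "no-captions"), fake_fetch)
print(list(found))  # ['vid1']
print(missing)      # {'no-captions': 'LookupError: no transcript available'}
```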

In [12]:
# lru_cache needs hashable arguments, so convert the video_ids list to a tuple before calling get_transcripts
all_transcripts = {playlist['id']['playlistId']: get_transcripts(tuple(all_video_ids))}
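`functools.lru_cache` uses the call arguments as a dictionary key, so every argument must be hashable; that is why the list of IDs is converted to a tuple before calling `get_transcripts`. A small standalone demonstration (`count_ids` is an illustrative function, not part of the notebook):

```python
from functools import lru_cache


@lru_cache(maxsize=50)
def count_ids(video_ids):
    # Illustrative stand-in for get_transcripts: any lru_cache-wrapped function behaves the same way
    return len(video_ids)


ids = ["fhzDrXCpIRQ", "v1AiTcwZaRY"]

try:
    count_ids(ids)  # lists are mutable, hence unhashable, so this raises TypeError
except TypeError as exc:
    print(f"TypeError: {exc}")

print(count_ids(tuple(ids)))        # 2 (computed on the first call)
print(count_ids(tuple(ids)))        # 2 (served from the cache)
print(count_ids.cache_info().hits)  # 1
```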

In [13]:
all_transcripts['PLPg7_faNDlT6wXMi2vfG0zJ6pK-gq6KE8']['fhzDrXCpIRQ'][0]

{'text': "hey guys it's yujin welcome to the first",
 'start': 5.6,
 'duration': 5.68}

In [15]:
playlist_id = list(all_transcripts.keys())[0]
transcripts = all_transcripts[playlist_id]

if transcripts:
    # Get the first video_id and its transcript
    video_id, transcript = next(iter(transcripts.items()))

    # Join and print the transcript text
    text = " ".join([entry['text'] for entry in transcript])
    display(text)
else:
    print("No transcripts available for this playlist.")

"hey guys it's yujin welcome to the first installment in our Vector database fundamentals Series today we're going to be talking about approximate nearest neighbors oh yeah or annoy there are three things that you need to know about annoy it was invented at Spotify it creates a binary tree out of your data set and it can be used as a memory efficient and computationally efficient Vector index so the way it works is you start out with your two data your two data points in your data set and then you draw the hyperplane that separates them exactly in the middle and that creates the first split of your binary tree then you do this for each Leaf node of your binary tree and now you have four and you keep doing that until you reach a predetermined hyper parameter perhaps the number of data points within a region so when you query this gives you a computationally and memory efficient query because you are looking through a tree and then doing math on a much smaller subset of data"

## Text Embedding: Hugging Face Transformers

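I chose this model (**ember-v1**) from Hugging Face for no special reason other than that it was in the top 3 on the MTEB leaderboard at the time.
- https://huggingface.co/llmrails/ember-v1
- Embedding dimension: 1024
- Max sequence length: 512 tokens
- Average MTEB score: 63.54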

---
### Limitation
This model supports English text only, and longer inputs are truncated to a maximum of 512 tokens, so for a long transcript the embedding reflects only its first 512 tokens.
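Because inputs beyond 512 tokens are truncated, a common workaround is to split each transcript into smaller chunks and embed each chunk separately, storing one row per chunk. The sketch below groups the transcript entries (`text`, `start`, `duration`, as returned by `get_transcript`) into chunks of at most `max_words` words and keeps each chunk's start time so a hit can point to a timestamp. Word count is only a rough proxy for token count (an assumption; counting with the model's `tokenizer` would be more precise), and `chunk_transcript` and its default of 300 words are illustrative choices.

```python
from typing import Dict, List


def chunk_transcript(entries: List[Dict], max_words: int = 300) -> List[Dict]:
    """Group transcript entries into chunks of at most max_words words.

    A single entry longer than max_words becomes its own chunk.
    Each chunk keeps the start time (in seconds) of its first entry.
    """
    chunks: List[Dict] = []
    words: List[str] = []
    start = None
    for entry in entries:
        entry_words = entry["text"].split()
        # Close the current chunk if adding this entry would exceed the word budget
        if words and len(words) + len(entry_words) > max_words:
            chunks.append({"text": " ".join(words), "start": start})
            words, start = [], None
        if start is None:
            start = entry["start"]
        words.extend(entry_words)
    if words:  # flush the final partial chunk
        chunks.append({"text": " ".join(words), "start": start})
    return chunks


# First entry is from the notebook output; timings of the other two are made up
entries = [
    {"text": "hey guys it's yujin welcome to the first", "start": 5.6, "duration": 5.68},
    {"text": "installment in our Vector database", "start": 11.28, "duration": 3.0},
    {"text": "fundamentals series", "start": 14.28, "duration": 2.0},
]
for chunk in chunk_transcript(entries, max_words=10):
    print(chunk)
```

Each chunk's `text` would then go through the embedding step, and the collection schema would need an extra field for the chunk's start time if you want to keep it.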

In [None]:
! pip install grpcio==1.58.0

In [1]:
from transformers import AutoTokenizer, AutoModel
from milvus import default_server
from pymilvus import (
    connections, utility, FieldSchema,
    DataType, CollectionSchema, Collection)
import torch

# Initialize the tokenizer and model
# https://huggingface.co/llmrails/ember-v1
tokenizer = AutoTokenizer.from_pretrained("llmrails/ember-v1")
model = AutoModel.from_pretrained("llmrails/ember-v1")

In [22]:
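def get_embedding(text):
    # Tokenize input text (anything past 512 tokens is truncated)
    inputs = tokenizer(text, max_length=512, padding=True, truncation=True, return_tensors='pt')
    with torch.no_grad():  # inference only, so skip gradient tracking
        outputs = model(**inputs)
    # Mean-pool the last hidden layer over real tokens only, using the attention mask
    mask = inputs['attention_mask'].unsqueeze(-1).float()
    embeddings = (outputs.last_hidden_state * mask).sum(dim=1) / mask.sum(dim=1)

    return embeddings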
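The embedding is produced by mean pooling: averaging the last hidden layer across token positions. With a single input there is no padding, but when several texts are tokenized together with `padding=True`, the shorter ones are padded, and a plain `.mean(dim=1)` would average the padding positions in as well. Weighting by the `attention_mask` excludes them. A dependency-free sketch of the arithmetic, using made-up 2-dimensional hidden states (`masked_mean` and `plain_mean` are illustrative helpers):

```python
from typing import List


def masked_mean(hidden: List[List[float]], mask: List[int]) -> List[float]:
    """Average token vectors, counting only positions where mask == 1."""
    dim = len(hidden[0])
    kept = [vec for vec, m in zip(hidden, mask) if m == 1]
    return [sum(vec[d] for vec in kept) / len(kept) for d in range(dim)]


def plain_mean(hidden: List[List[float]]) -> List[float]:
    """Average over every position, padding included (what .mean(dim=1) does)."""
    dim = len(hidden[0])
    return [sum(vec[d] for vec in hidden) / len(hidden) for d in range(dim)]


# One sequence: two real tokens followed by one padding position
hidden = [[1.0, 2.0], [3.0, 4.0], [9.0, 9.0]]
mask = [1, 1, 0]

print(masked_mean(hidden, mask))  # [2.0, 3.0]
print(plain_mean(hidden))         # [4.333333333333333, 5.0]
```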

## Storing Embeddings: Milvus

In [16]:
# Cleanup previous data and stop server in case it is still running.
default_server.stop()
default_server.cleanup()

# Start the Milvus server
default_server.start()

# Connect via localhost on the port given by default_server.listen_port
connections.connect(host='127.0.0.1', port=default_server.listen_port)

# Check if the server is ready.
print(utility.get_server_version())

v2.3.1-lite


In [19]:
# Fields for the collection
fields = [
    FieldSchema(name="id", dtype=DataType.INT64, is_primary=True, auto_id=True),
    FieldSchema(name="embedding", dtype=DataType.FLOAT_VECTOR, dim=1024),
    FieldSchema(name="video_id", dtype=DataType.VARCHAR, max_length=255)  # metadata
]

# Create a schema for the collection
schema = CollectionSchema(fields=fields, description="Zilliz Vector DB Fundamentals")

# Create the collection
collection_name = "Zilliz_Youtube_Transcripts"
if not utility.has_collection(collection_name):
    collection = Collection(name=collection_name, schema=schema)
else:
    collection = Collection(name=collection_name)
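The schema fixes two constraints every inserted row must satisfy: `embedding` must contain exactly `dim=1024` floats, and `video_id` must fit within `max_length=255`. A pre-insert check can catch a mismatch (for example, after switching to an embedding model with a different dimension) earlier and with a clearer message. `validate_row` is a hypothetical helper, not part of pymilvus; the length check is done in UTF-8 bytes to be conservative, which equals the character count for ASCII YouTube IDs.

```python
from typing import Dict, Sequence

EMBEDDING_DIM = 1024    # must match FieldSchema(name="embedding", dim=1024)
VIDEO_ID_MAX_LEN = 255  # must match FieldSchema(name="video_id", max_length=255)


def validate_row(row: Dict) -> None:
    """Raise ValueError if a row would not fit the collection schema."""
    embedding: Sequence[float] = row["embedding"]
    if len(embedding) != EMBEDDING_DIM:
        raise ValueError(f"embedding has {len(embedding)} dims, expected {EMBEDDING_DIM}")
    # Conservative check in UTF-8 bytes; for ASCII IDs this equals the character count
    if len(row["video_id"].encode("utf-8")) > VIDEO_ID_MAX_LEN:
        raise ValueError("video_id is longer than the VARCHAR max_length")


validate_row({"embedding": [0.0] * 1024, "video_id": "fhzDrXCpIRQ"})  # passes silently
try:
    validate_row({"embedding": [0.0] * 768, "video_id": "fhzDrXCpIRQ"})
except ValueError as exc:
    print(exc)  # embedding has 768 dims, expected 1024
```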

In [20]:
if utility.has_collection(collection_name): # Ensure the collection exists
    collection = Collection(name=collection_name)
    print(collection.schema)
else:
    print(f"Collection {collection_name} does not exist.")

{'auto_id': True, 'description': 'Zilliz Vector DB Fundamentals', 'fields': [{'name': 'id', 'description': '', 'type': <DataType.INT64: 5>, 'is_primary': True, 'auto_id': True}, {'name': 'embedding', 'description': '', 'type': <DataType.FLOAT_VECTOR: 101>, 'params': {'dim': 1024}}, {'name': 'video_id', 'description': '', 'type': <DataType.VARCHAR: 21>, 'params': {'max_length': 255}}]}


In [36]:
for video_id, transcript in transcripts.items():
    text = " ".join([entry['text'] for entry in transcript])

    # Get the embedding
    embedding = get_embedding(text)
    embedding_list = embedding.detach().cpu().numpy().flatten().tolist()

    # Prepare the data for insertion
    entities = {
        "embedding": embedding_list,
        "video_id": video_id
    }

    # Store the embedding and video_id in Milvus
    ids = collection.insert([entities])
    display(ids)

# Flush once, after the final entity is inserted, so no growing segments are left in memory
collection.flush()

(insert count: 1, delete count: 0, upsert count: 0, timestamp: 445231942090096641, success count: 1, err count: 0)

(insert count: 1, delete count: 0, upsert count: 0, timestamp: 445231944606679041, success count: 1, err count: 0)

(insert count: 1, delete count: 0, upsert count: 0, timestamp: 445231946533437443, success count: 1, err count: 0)

(insert count: 1, delete count: 0, upsert count: 0, timestamp: 445231949548093441, success count: 1, err count: 0)

(insert count: 1, delete count: 0, upsert count: 0, timestamp: 445231952641392641, success count: 1, err count: 0)

(insert count: 1, delete count: 0, upsert count: 0, timestamp: 445231954437079041, success count: 1, err count: 0)

(insert count: 1, delete count: 0, upsert count: 0, timestamp: 445231956075479041, success count: 1, err count: 0)

(insert count: 1, delete count: 0, upsert count: 0, timestamp: 445231957858058242, success count: 1, err count: 0)
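Inserting one row per `collection.insert` call works for eight videos, but each call is a round trip to the server. For larger playlists, rows can be accumulated and sent in batches. A minimal sketch of the batching logic, assuming a hypothetical `insert_fn` (in the notebook this would be `collection.insert`, which accepts a list of row dicts as used above) and an arbitrary `batch_size`:

```python
from typing import Callable, Dict, Iterable, List


def insert_in_batches(
    rows: Iterable[Dict], insert_fn: Callable[[List[Dict]], object], batch_size: int = 100
) -> int:
    """Send rows to insert_fn in lists of at most batch_size; return the number of calls."""
    calls = 0
    batch: List[Dict] = []
    for row in rows:
        batch.append(row)
        if len(batch) == batch_size:
            insert_fn(batch)
            calls += 1
            batch = []  # start a new list; the previous one was handed to insert_fn
    if batch:  # remaining partial batch
        insert_fn(batch)
        calls += 1
    return calls


# Stand-in for collection.insert that just records what it receives
received: List[List[Dict]] = []
rows = ({"embedding": [0.0] * 4, "video_id": f"vid{i}"} for i in range(8))
print(insert_in_batches(rows, received.append, batch_size=3))  # 3
print([len(b) for b in received])                                # [3, 3, 2]
```

A single `collection.flush()` after the last batch keeps flushes to a minimum.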

In [37]:
print(len(embedding_list))  # Should print 1024
print(embedding_list[:5])

1024
[0.041703153401613235, 0.6804148554801941, 0.7939674258232117, -0.43260687589645386, -0.34241166710853577]


In [38]:
print(collection.is_empty)  # False if data has been inserted
print(collection.num_entities)  # number of stored rows, including any from earlier runs of the insert cell

False
24


## Indexing and Search

In [31]:
index_params = {
    "metric_type": "L2",
    "index_type": "IVF_FLAT",
    "params": {"nlist": 128}
}
collection.create_index(field_name="embedding", index_params=index_params)

Status(code=0, message=)
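`IVF_FLAT` partitions the vectors into `nlist` clusters (inverted lists) around centroids; at query time only the `nprobe` lists whose centroids are closest to the query are scanned exhaustively. A larger `nprobe` trades speed for recall, and `nprobe = nlist` amounts to a brute-force search. The toy sketch below shows that search path in plain Python. It is not Milvus's implementation: the centroids are supplied by hand rather than trained with k-means, and distances are squared L2.

```python
from typing import Dict, List, Sequence, Tuple


def sq_l2(a: Sequence[float], b: Sequence[float]) -> float:
    """Squared Euclidean distance."""
    return sum((x - y) ** 2 for x, y in zip(a, b))


def build_ivf(vectors: List[Sequence[float]], centroids: List[Sequence[float]]) -> Dict[int, List[int]]:
    """Assign each vector id to the inverted list of its nearest centroid."""
    lists: Dict[int, List[int]] = {c: [] for c in range(len(centroids))}
    for vid, vec in enumerate(vectors):
        nearest = min(range(len(centroids)), key=lambda c: sq_l2(vec, centroids[c]))
        lists[nearest].append(vid)
    return lists


def ivf_search(query, vectors, centroids, lists, nprobe: int, limit: int) -> List[Tuple[int, float]]:
    """Scan only the nprobe lists closest to the query; return (id, distance) pairs."""
    probe = sorted(range(len(centroids)), key=lambda c: sq_l2(query, centroids[c]))[:nprobe]
    candidates = [(vid, sq_l2(query, vectors[vid])) for c in probe for vid in lists[c]]
    return sorted(candidates, key=lambda pair: pair[1])[:limit]


# Two well-separated clusters in 2-D with hand-picked centroids (nlist = 2)
vectors = [[0.0, 0.0], [0.0, 1.0], [10.0, 10.0], [10.0, 11.0]]
centroids = [[0.0, 0.5], [10.0, 10.5]]
lists = build_ivf(vectors, centroids)
print(lists)  # {0: [0, 1], 1: [2, 3]}

query = [9.0, 9.0]
print(ivf_search(query, vectors, centroids, lists, nprobe=1, limit=1))  # [(2, 2.0)]
```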

In [32]:
# Load the collection into memory
collection.load()

In [70]:
query_embedding = get_embedding("what is Cosine Similarity?")

# Convert the (1, 1024) tensor to a flat list of floats, as done for the inserted embeddings
flattened_embedding = query_embedding.detach().cpu().numpy().flatten().tolist()

# The search metric must match the metric the index was built with
search_params = {"metric_type": "L2", "params": {"nprobe": 10}}
results = collection.search(
    [flattened_embedding],
    "embedding",
    param=search_params,
    limit=1,
    output_fields=["video_id"]
)

In [53]:
print(len(flattened_embedding))  # Should print 1024

1024


In [71]:
results

['["id: 445231527974932243, distance: 98.24575805664062, entity: {\'video_id\': \'o-eE-HZzUQY\'}"]']
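With `metric_type="L2"`, Milvus reports the squared Euclidean distance, so a smaller `distance` means a closer match, and the 98.25 above is only meaningful relative to other hits. Since the query is about cosine similarity, it is worth noting how the two relate: if both vectors are scaled to unit length, squared L2 distance equals 2 - 2·cos(θ), so ranking by L2 and ranking by cosine similarity agree. The mean-pooled embeddings here are not normalized, so that equivalence would require normalizing them before insert and query (or, on Milvus 2.3 and later, using the `COSINE` metric type, which this notebook does not use). A quick check of the identity, treating the numbers from the transcript's dot-product example as the vectors (0.3, 0.9) and (0.5, 0.7):

```python
import math
from typing import List, Sequence


def cosine_similarity(a: Sequence[float], b: Sequence[float]) -> float:
    """Dot product divided by the product of the magnitudes."""
    dot = sum(x * y for x, y in zip(a, b))
    return dot / (math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b)))


def normalize(v: Sequence[float]) -> List[float]:
    """Scale a vector to unit length."""
    norm = math.sqrt(sum(x * x for x in v))
    return [x / norm for x in v]


def squared_l2(a: Sequence[float], b: Sequence[float]) -> float:
    """Squared Euclidean distance, the quantity Milvus reports for the L2 metric."""
    return sum((x - y) ** 2 for x, y in zip(a, b))


v1, v2 = [0.3, 0.9], [0.5, 0.7]
cos = cosine_similarity(v1, v2)
l2_unit = squared_l2(normalize(v1), normalize(v2))

print(round(cos, 4))                       # 0.9558
print(math.isclose(l2_unit, 2 - 2 * cos))  # True
```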

## Retrieving and Displaying Results

Map the results back to the original transcripts and display them.

In [72]:
# Iterate through the results
for result in results[0]:
    video_id = result.entity.get("video_id")
    # Fetch the playlist and video from your transcripts data
    for playlist_id, transcripts in all_transcripts.items():
        transcript = transcripts.get(video_id)
        if transcript:
            # You've found the transcript, now display it or do whatever you need
            text = " ".join([entry['text'] for entry in transcript])
            display(f'Transcript: {text}')
            break

"Transcript: hey guys welcome to another installment in our Vector database fundamental Series today we're going to be talking about cosine similarity cosine similarity is the angle between two vectors which I've marked here with this purple line the way that cosine similarity is calculated looks pretty complicated but what it really is um in terms of a concept is it's the dot product or the inner product which we looked at earlier and then it is divided by the magnitude of the vectors and so what this is is essentially it's normalized in our product as if every Vector had a magnitude of one and you'll look and you can see that queen and king in this example don't hide magnitudes of one but this is that's when it is best used because it is more efficient at that point so the cosine similarity here is done by doing a DOT product and then dividing by the magnitudes so in our case it's what we just did earlier the dot product 0.3 times 0.5 plus 0.9 times 0.7 and then we also have to take 
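The retrieval loop searches every playlist for each hit. Since the search only returns `video_id`, a flat `video_id`-to-text lookup built once up front turns the mapping into a single dictionary access per hit. A minimal sketch over the same `{playlist_id: {video_id: [entries]}}` shape as `all_transcripts`, with made-up sample data (`build_text_lookup` is a hypothetical helper):

```python
from typing import Dict, List

Transcripts = Dict[str, Dict[str, List[dict]]]  # {playlist_id: {video_id: [entries]}}


def build_text_lookup(all_transcripts: Transcripts) -> Dict[str, str]:
    """Flatten playlists into {video_id: full transcript text}."""
    return {
        video_id: " ".join(entry["text"] for entry in entries)
        for playlist in all_transcripts.values()
        for video_id, entries in playlist.items()
    }


sample = {
    "playlist-1": {
        "vid-a": [
            {"text": "cosine similarity is", "start": 0.0, "duration": 1.0},
            {"text": "the angle between vectors", "start": 1.0, "duration": 1.0},
        ],
    }
}
text_by_video = build_text_lookup(sample)

# Each search hit then needs one lookup, e.g. text_by_video.get(hit.entity.get("video_id"))
print(text_by_video.get("vid-a"))    # cosine similarity is the angle between vectors
print(text_by_video.get("missing"))  # None
```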

## Cleanup

Shutting down and cleaning up the Milvus server.

In [73]:
# Shut down and cleanup the milvus server.
default_server.stop()
default_server.cleanup()

## Some TODOs

1. Error handling
2. Logging
3. Testing
4. Manage config / security
5. Code modularity
6. Performance optimization
7. Documentation
8. Handling YouTube API rate limiting