In [None]:
from datasets import load_dataset, Video
from torch.utils.data import DataLoader, IterableDataset
import webdataset as wds
import os
import base64
import json
import spacy
from spacy import displacy
import re


In [None]:
# Pick a worker count: logical CPU count plus four, capped at 32.
# os.cpu_count() may return None, in which case 1 is assumed.
cpu_total = os.cpu_count()
print(cpu_total)
num_threads = min(32, (cpu_total or 1) + 4)
print(num_threads)

In [None]:
from datasets import load_dataset

# Open the "train" split of the WebDataset shards in streaming mode.
# NOTE(review): the data directory is an absolute cluster path; consider
# moving it to a configurable constant.
dataset = load_dataset(
    "webdataset",
    data_dir="/mmfs1/project/phan/tqn/RAG-VideoReferencing/data/",
    split='train',
    streaming=True,
)

In [None]:
# Show the dataset object's repr as this cell's output.
dataset

In [None]:
# Write every sample's MP4 payload to temp_video_<idx>.mp4 in the working
# directory; `test` ends up holding the captions of the last sample seen.
for idx, data in enumerate(dataset):
    temp_video_path = f"temp_video_{idx}.mp4"
    with open(temp_video_path, "wb") as video_file:
        video_file.write(data["mp4"])
    test = data['json']['captions']

# Process the JSON caption metadata

In [None]:
# Step 1: Clean the caption texts and drop captions that are substrings of another caption

def clean_text(text):
    """Normalise a caption string and remove spoken filler words.

    Newlines become spaces, the fillers "um"/"uh" and the variants
    "okay"/"ok"/"o.k."/"o.k" are removed as whole words (case-insensitive),
    and runs of spaces left behind are collapsed.

    Args:
        text: raw caption text.

    Returns:
        The cleaned text with no leading or trailing whitespace.
    """
    text = text.strip().replace("\n", " ")
    # Remove 'um' and 'uh' as full words.
    text = re.sub(r'\b(?:um|uh)\b', '', text, flags=re.IGNORECASE)
    # Remove variations of 'okay'.  The dotted form needs special care: a
    # trailing "\b" can never match between "." and a space, so "o.k." used to
    # leave its final period behind.  Here the period is consumed whenever it
    # is present, and a word boundary is only required for the bare "o.k" form.
    text = re.sub(r'\b(?:okay\b|ok\b|o\.k(?:\.|\b))', '', text, flags=re.IGNORECASE)
    # Collapse the extra spaces caused by the removals above.
    text = re.sub(' +', ' ', text)
    return text.strip()

def is_substring(smaller, larger):
    """Return True when `smaller` occurs inside `larger`, ignoring case."""
    needle = smaller.lower()
    haystack = larger.lower()
    return haystack.find(needle) != -1

def process_cleaning(captions, video_path):
    """Clean caption texts and drop captions made redundant by another one.

    Every caption is copied into a new dict that carries `video_path` and the
    text returned by `clean_text`.  A caption is discarded when its cleaned
    text is contained (case-insensitively, via `is_substring`) in a
    *different* caption's text.  Captions whose cleaned texts are identical
    are duplicates of each other: the first occurrence is kept and only the
    later ones are dropped (previously every copy was removed, losing the
    text entirely).

    Args:
        captions: sized iterable of dicts with "start", "end" and "text" keys.
        video_path: value stored under "video_path" on every returned caption.

    Returns:
        List of dicts with keys "video_path", "start", "end", "text", in the
        original caption order.
    """
    print(f"original num: {len(captions)} captions")
    caption_texts = [
        {
            "video_path": video_path,
            "start": caption["start"],
            "end": caption["end"],
            "text": clean_text(caption["text"]),
        }
        for caption in captions
    ]

    # Lower-cased keys used only to recognise exact duplicates; lower() is the
    # same folding is_substring applies.
    lowered = [cap["text"].lower() for cap in caption_texts]

    filtered_captions = []
    for i, cap_i in enumerate(caption_texts):
        redundant = False
        for j, cap_j in enumerate(caption_texts):
            if i == j:
                continue
            if lowered[i] == lowered[j]:
                # Exact duplicate: only an earlier copy makes this one redundant.
                if j < i:
                    redundant = True
                    break
                continue
            if is_substring(cap_i["text"], cap_j["text"]):
                redundant = True
                break
        if not redundant:
            filtered_captions.append(cap_i)
    print(f"STEP 2: Loaded {len(filtered_captions)} captions")
    return filtered_captions

# Clean the captions of every video; entry i belongs to temp_video_<i>.mp4.
filtered_captions = []
for idx, data in enumerate(dataset):
    filtered_captions.append(
        process_cleaning(data['json']['captions'], f"temp_video_{idx}.mp4")
    )
print(filtered_captions)

In [None]:
# Step 2: Fix overlaps between consecutive captions
def find_overlap(a, b):
    """Length of the longest word-aligned suffix of `a` that is a prefix of `b`.

    The comparison is case-insensitive.  An overlap only counts when it covers
    whole words: it must start at the beginning of `a` or right after
    whitespace, and end at the end of `b` or right before whitespace.  Without
    that restriction accidental character matches ("breathe" + "the end",
    "cats" + "so") were reported as overlaps and the following caption lost the
    beginning of its first word.

    Args:
        a: text of the earlier caption.
        b: text of the later caption.

    Returns:
        Number of characters of `b` that repeat the end of `a`; 0 if none.
    """
    longest = min(len(a), len(b))
    for size in range(longest, 0, -1):
        # The candidate suffix of `a` must begin on a word boundary.
        if size < len(a) and not a[-size - 1].isspace():
            continue
        # The candidate prefix of `b` must end on a word boundary.
        if size < len(b) and not b[size].isspace():
            continue
        if a[-size:].lower() == b[:size].lower():
            return size
    return 0

def process_overlap_elimination(captions):
    """Trim from each caption the text it repeats from the previous caption.

    The first caption is kept unchanged.  For every later caption the overlap
    reported by `find_overlap` is cut from the front of its text; captions
    that become empty are dropped.

    The overlap is measured against the *original* text of the last kept
    caption, not its trimmed remainder.  With rolling captions such as
    "one two three" / "two three four" / "three four five" the trimmed
    remainder of the second caption is only " four", which no longer overlaps
    the third caption, so repeated text used to survive.

    Args:
        captions: list of dicts with "video_path", "start", "end", "text".

    Returns:
        List of caption dicts whose texts no longer repeat the preceding text.
    """
    merged_chunks = []
    prev_text = None  # untrimmed text of the most recently kept caption
    for cap in captions:
        if prev_text is None:
            merged_chunks.append(cap)
            prev_text = cap['text']
            continue

        overlap = find_overlap(prev_text, cap['text'])
        cap_text_fixed = cap['text'][overlap:]
        if cap_text_fixed:
            merged_chunks.append({
                "video_path": cap["video_path"],
                "start": cap["start"],
                "end": cap["end"],
                "text": cap_text_fixed
            })
            # Only a kept caption becomes the new reference; a fully repeated
            # caption adds nothing the reference does not already contain.
            prev_text = cap['text']
    return merged_chunks
# Remove caption-to-caption overlaps for every video, keeping video order.
total_captions = [
    process_overlap_elimination(video_captions)
    for video_captions in filtered_captions
]
print(total_captions)

In [None]:
# step 3: Merge all text and split using spacy
def clean_spaces(text):
    """Collapse every whitespace run into one space and trim both ends."""
    whitespace_run = re.compile(r'\s+')
    collapsed = whitespace_run.sub(' ', text)
    return collapsed.strip()

nlp = spacy.load("en_core_web_sm")

# chunked_sentences[i] holds the sentences of total_captions[i].  Later steps
# index total_captions and temp_video_<i>.mp4 by position in this list, so
# every video must get exactly one entry, even when it has no usable sentences.
chunked_sentences = []
for captions in total_captions:
    full_text = " ".join(chunk["text"] for chunk in captions).strip()
    doc = nlp(full_text)

    # Keep only sentences that contain at least two words.
    sentences = []
    for sent in doc.sents:
        sent_text = clean_spaces(sent.text)
        # Skip empty or single-character sentences.
        if len(sent_text) <= 1:
            continue

        # Count words in the sentence (splitting by whitespace).
        if len(sent_text.split()) <= 1:
            print(f"Filtering out single-word sentence: '{sent_text}'")
            continue
        sentences.append(sent_text)

    # Videos with at most one valid sentence contribute nothing, but an empty
    # placeholder is stored so the list stays aligned with total_captions
    # (skipping the video shifted every following index by one).
    chunked_sentences.append(sentences if len(sentences) > 1 else [])

    print(f"STEP 3: Found {len(sentences)} valid sentences after filtering")

In [None]:
# Print the sentences stored at index 6, one per line.
sample_sentences = chunked_sentences[6]
for sentence in sample_sentences:
    print(sentence)

In [None]:
# Print the sizes of the sentence list and the caption list, one per line.
print(len(chunked_sentences), len(total_captions), sep="\n")

In [None]:
# Step 6: Find new timestamps for each segmented sentence
def _align_sentences_to_chunks(sentences, chunks, video_path, min_chars=100):
    """Assign start/end timestamps to sentences using caption character offsets.

    The caption texts are whitespace-normalised and joined with single spaces,
    recording the character span each caption occupies.  Each sentence is then
    located in that transcript (searching forward from the end of the previous
    sentence) and receives the start time of the first caption it touches and
    the end time of the last one.

    This replaces a pointer-based accumulation that consumed whole captions
    per sentence and threw away the part of a caption belonging to the next
    sentence, so timestamps drifted later and later and the final sentences
    could end up with no captions at all.

    Args:
        sentences: sentence strings in transcript order.
        chunks: caption dicts with "start", "end" and "text".
        video_path: value stored under "video_path" in each result.
        min_chars: sentences shorter than this are not emitted.

    Returns:
        List of dicts with keys "video_path", "start", "end", "sentence".
    """
    transcript_parts = []
    chunk_spans = []  # (start_offset, end_offset, chunk) within the transcript
    offset = 0
    for chunk in chunks:
        text = re.sub(r'\s+', ' ', chunk["text"]).strip()
        if not text:
            continue
        if transcript_parts:
            offset += 1  # the single space that joins consecutive captions
        chunk_spans.append((offset, offset + len(text), chunk))
        transcript_parts.append(text)
        offset += len(text)
    transcript = " ".join(transcript_parts)

    aligned = []
    cursor = 0  # transcript offset where the search for the next sentence starts
    first = 0   # index of the first caption that can still overlap a sentence
    for sentence in sentences:
        begin = transcript.find(sentence, cursor)
        if begin == -1:
            # Sentence not found verbatim: assume it follows the previous one.
            begin = cursor
        finish = min(begin + len(sentence), len(transcript))
        cursor = finish

        # Skip captions that end before the sentence begins.
        while first < len(chunk_spans) and chunk_spans[first][1] <= begin:
            first += 1
        if first >= len(chunk_spans):
            continue  # no captions left to take timestamps from

        # Extend to the last caption that starts before the sentence ends.
        last = first
        while last + 1 < len(chunk_spans) and chunk_spans[last + 1][0] < finish:
            last += 1

        sentence_start_time = chunk_spans[first][2]["start"]
        sentence_end_time = chunk_spans[last][2]["end"]
        if not sentence_start_time or not sentence_end_time:
            continue
        if len(sentence) < min_chars:
            continue
        aligned.append({
            "video_path": video_path,
            "start": sentence_start_time,
            "end": sentence_end_time,
            "sentence": sentence
        })
    return aligned


# chunked_results[i] holds the timestamped sentences of video i.
chunked_results = []
for idx, sentences in enumerate(chunked_sentences):
    chunked_results.append(
        _align_sentences_to_chunks(
            sentences, total_captions[idx], f"temp_video_{idx}.mp4"
        )
    )



In [None]:
# Print how many per-video result lists were produced.
print(len(chunked_results))

In [None]:
import os
import cv2
import numpy as np
from PIL import Image
from bertopic.backend import MultiModalBackend
import umap
import hdbscan
import faiss
from sklearn.preprocessing import normalize

In [None]:
# Step 7: Extract frames and compute fused CLIP embeddings
import cv2
import clip
import torch
from PIL import Image
import numpy as np
import umap
import os
import json
import glob 
from sklearn.preprocessing import normalize
from qdrant_client import QdrantClient
from qdrant_client.http.models import PointStruct, Distance, VectorParams, Filter, FieldCondition, MatchValue
import uuid


# Use the GPU when one is available, otherwise the CPU, and load CLIP
# ViT-B/32 together with its image preprocessing transform.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
model, preprocess = clip.load("ViT-B/32", device=device)

def extract_frames_between_times(video_path, start_time, end_time, frame_rate=1):
    """Sample frames from a video between two timestamps.

    Args:
        video_path: path of the video file to read.
        start_time: segment start, in a format accepted by `time_to_seconds`.
        end_time: segment end, in a format accepted by `time_to_seconds`.
        frame_rate: number of frames to sample per second of video.

    Returns:
        List of frames as returned by `cv2.VideoCapture.read`, in time order.
        The list is empty when the video cannot be opened, reports no usable
        FPS, or the segment spans no frames.

    Raises:
        ValueError: if `frame_rate` is not positive.
    """
    if frame_rate <= 0:
        raise ValueError("frame_rate must be positive")

    cap = cv2.VideoCapture(video_path)
    try:
        fps = cap.get(cv2.CAP_PROP_FPS)
        # An unreadable file reports an FPS of 0, which previously produced a
        # zero range step and a ValueError; `not (fps > 0)` also rejects NaN.
        if not cap.isOpened() or not (fps > 0):
            return []

        start_frame = int(time_to_seconds(start_time) * fps)
        end_frame = int(time_to_seconds(end_time) * fps)
        # Never step by 0 frames, even when frame_rate exceeds the video FPS.
        step = max(1, int(fps / frame_rate))

        frames = []
        for frame_num in range(start_frame, end_frame, step):
            cap.set(cv2.CAP_PROP_POS_FRAMES, frame_num)
            ret, frame = cap.read()
            if not ret:
                break
            frames.append(frame)
        return frames
    finally:
        # Release the capture handle even if decoding raised.
        cap.release()

def time_to_seconds(time_str):
    """Convert a timestamp to a number of seconds.

    Accepted inputs:
        * "HH:MM:SS[.fff]" (the original format),
        * "MM:SS[.fff]" and "SS[.fff]",
        * the same forms with a comma as decimal separator ("00:00:01,500"),
        * an int or float, interpreted as seconds already.

    Args:
        time_str: timestamp string or numeric seconds.

    Returns:
        The time in seconds as a float.

    Raises:
        ValueError: if the string has more than three ':'-separated fields or
            a field is not numeric.
    """
    if isinstance(time_str, (int, float)):
        return float(time_str)

    fields = str(time_str).strip().replace(',', '.').split(':')
    if len(fields) > 3:
        raise ValueError(f"Unrecognised timestamp: {time_str!r}")

    *leading, seconds = fields
    # Hours and minutes are summed as integers first so that the result for
    # "HH:MM:SS" input is computed exactly as before.
    whole = 0
    for multiplier, field in zip((60, 3600), reversed(leading)):
        whole += int(field) * multiplier
    return whole + float(seconds)

@torch.no_grad()
def get_multimodal_embeddings(docs, image_tensor, model_name='clip-ViT-B-32', batch_size=32):
    """Fuse a text embedding and the mean image embedding from CLIP.

    The text is cut to its first 50 whitespace-separated words and tokenised
    with truncation.  Images are preprocessed and encoded in batches of
    `batch_size`, averaged into one vector, and both modalities are
    L2-normalised before being averaged together.

    Args:
        docs: text to embed.
        image_tensor: a single image or a list of images accepted by the
            module-level `preprocess` transform.  None or an empty list is
            allowed (e.g. a segment from which no frame could be read); the
            normalised text embedding alone is returned in that case instead
            of failing on an empty `torch.stack`.
        model_name: unused; kept for interface compatibility.
        batch_size: maximum number of images encoded per forward pass.

    Returns:
        NumPy array holding the fused embedding with singleton dimensions
        squeezed out.
    """
    # Simple word-level truncation before tokenising with truncate=True.
    truncated_text = ' '.join(docs.split()[:50])
    text_tokens = clip.tokenize([truncated_text], truncate=True).to(device)
    text_features = model.encode_text(text_tokens)
    text_features = text_features / text_features.norm(dim=1, keepdim=True)

    if image_tensor is None:
        images = []
    elif isinstance(image_tensor, list):
        images = image_tensor
    else:
        images = [image_tensor]  # convert a single image to a list

    if not images:
        # No visual evidence available: fall back to the text embedding.
        return text_features.cpu().numpy().squeeze()

    # Preprocess and move one batch at a time so only `batch_size` images are
    # resident on the device at once.
    batch_size = max(1, int(batch_size))
    feature_batches = []
    for begin in range(0, len(images), batch_size):
        batch = torch.stack(
            [preprocess(img) for img in images[begin:begin + batch_size]]
        ).to(device)
        feature_batches.append(model.encode_image(batch))
    image_features = torch.cat(feature_batches, dim=0)

    # Average the image features when several images were given.
    if image_features.shape[0] > 1:
        image_features = torch.mean(image_features, dim=0, keepdim=True)
    image_features = image_features / image_features.norm(dim=1, keepdim=True)

    # Combine both modalities by averaging.
    fused = (text_features + image_features) / 2.0
    return fused.cpu().numpy().squeeze()

def reduce_dimensions(embeddings, n_components=20):
    """Project `embeddings` to `n_components` dimensions with cosine-metric UMAP.

    Returns whatever `fit_transform` of the freshly built reducer yields.
    """
    return umap.UMAP(n_components=n_components, metric='cosine').fit_transform(embeddings)

# def cluster_embeddings(embeddings, min_cluster_size=5):
#     clusterer = hdbscan.HDBSCAN(min_cluster_size=min_cluster_size, metric='euclidean')
#     labels = clusterer.fit_predict(embeddings)
#     return labels

def search(index, query_embedding, top_k=5):
    """Query `index` with the L2-normalised embedding and return (distances, indices).

    The embedding is reshaped to a single row before normalisation, and
    `index.search` is asked for the `top_k` best matches.
    """
    query_row = query_embedding.reshape(1, -1)
    unit_query = normalize(query_row, norm='l2')
    scores, ids = index.search(unit_query, top_k)
    return scores, ids

# Embed every timestamped sentence together with the frames sampled from its
# time span, and prepare one Qdrant point per sentence.
all_embeddings = []
points = []
for idx, results in enumerate(chunked_results):
    temp_video_path = f"temp_video_{idx}.mp4"
    for result in results:
        frames = extract_frames_between_times(
            temp_video_path, result['start'], result['end'], frame_rate=1
        )
        # OpenCV frames are converted from BGR to RGB before wrapping in PIL.
        images = []
        for frame in frames:
            rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            images.append(Image.fromarray(rgb_frame))
        embeddings = get_multimodal_embeddings(result['sentence'], image_tensor=images)
        all_embeddings.append(embeddings)
        points.append(
            PointStruct(
                id=str(uuid.uuid4()),
                vector=embeddings.tolist(),
                payload={
                    "video_path": result["video_path"],
                    "start": result["start"],
                    "end": result["end"],
                    "text": result["sentence"],
                },
            )
        )

reduced_embeddings = reduce_dimensions(all_embeddings, n_components=2)
print(reduced_embeddings.shape)

# Both the full embeddings and the reduced embeddings are used from here on.

In [None]:
# Connect to Qdrant.  The API key is read from the environment so that no
# secret is stored in the notebook; set QDRANT_API_KEY (and optionally
# QDRANT_URL) before running this cell.
# NOTE(review): the key that used to be hard-coded here has been committed to
# source control and should be rotated.
client = QdrantClient(
    url=os.environ.get(
        "QDRANT_URL",
        "https://265484ec-5f64-40ec-a619-c7c9dffc2dd9.us-east-1-0.aws.cloud.qdrant.io:6333",
    ),
    api_key=os.environ["QDRANT_API_KEY"],
)

print(client.get_collections())

COLLECTION_NAME = "video_segments"

# Step 1: (Re)create the Qdrant collection.  recreate_collection drops any
# existing collection with this name first, so every run starts empty.
# NOTE(review): recent qdrant-client releases appear to deprecate
# recreate_collection - confirm against the installed version.
client.recreate_collection(
    collection_name=COLLECTION_NAME,
    # 512 is the embedding size of the CLIP ViT-B/32 model loaded above.
    vectors_config=VectorParams(size=512, distance=Distance.COSINE),
)
client.upsert(collection_name=COLLECTION_NAME, points=points)

print("Uploaded to Qdrant.")

In [None]:
# Fetch the collection description and report how many points it stores.
info = client.get_collection(COLLECTION_NAME)
print("Total points: " + str(info.points_count))

In [None]:
# Step 5: Query with TEXT ONLY
def search_with_text_query(text_query: str, top_k=5):
    """Search the Qdrant collection with a text-only CLIP embedding and print the hits.

    The query is tokenised with `truncate=True`, matching how the indexed
    sentences were tokenised, so queries longer than CLIP's context length are
    shortened instead of raising an error.

    Args:
        text_query: natural-language query.
        top_k: number of hits to request and print.

    Returns:
        None.  For every hit the rank, score, time span, subtitle text and
        video path are printed.
    """
    with torch.no_grad():
        text_tokens = clip.tokenize([text_query], truncate=True).to(device)
        text_features = model.encode_text(text_tokens)
        # L2-normalise the query embedding.
        text_features /= text_features.norm(dim=1, keepdim=True)

    # Search in Qdrant
    search_result = client.search(
        collection_name=COLLECTION_NAME,
        query_vector=text_features.cpu().numpy()[0].tolist(),
        limit=top_k,
    )

    for i, hit in enumerate(search_result):
        print(f"\nRank {i+1}:")
        print(f"Score: {hit.score}")
        print(f"Time: {hit.payload['start']} - {hit.payload['end']}")
        print(f"Subtitle: {hit.payload['text']}")
        print(f"video path: {hit.payload['video_path']}")

# Example query: print the best-matching video segments for a question.
example_question = "What is regression problem for image processing"
search_with_text_query(example_question)