In [1]:
!pip install pyngrok flask pyspark spark-nlp==5.3.0 faiss-cpu

Collecting pyngrok
  Downloading pyngrok-7.2.8-py3-none-any.whl.metadata (10 kB)
Collecting spark-nlp==5.3.0
  Downloading spark_nlp-5.3.0-py2.py3-none-any.whl.metadata (57 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m57.1/57.1 kB[0m [31m4.4 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting faiss-cpu
  Downloading faiss_cpu-1.11.0-cp311-cp311-manylinux_2_28_x86_64.whl.metadata (4.8 kB)
Downloading spark_nlp-5.3.0-py2.py3-none-any.whl (564 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m564.8/564.8 kB[0m [31m27.4 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading pyngrok-7.2.8-py3-none-any.whl (25 kB)
Downloading faiss_cpu-1.11.0-cp311-cp311-manylinux_2_28_x86_64.whl (31.3 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m31.3/31.3 MB[0m [31m93.3 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: spark-nlp, pyngrok, faiss-cpu
Successfully installed faiss-cpu-1.11.0 pyngrok-7.2.8 spark-nlp-5.3.0


In [1]:
import os
# Paste your own ngrok authtoken between the quotes, or leave it empty to keep
# whatever NGROK_AUTHTOKEN is already exported in the environment.
# Do not commit a notebook that contains a real token.
NGROK_AUTHTOKEN_VALUE = ""
if NGROK_AUTHTOKEN_VALUE:
    # Only touch the environment when a token was actually provided, so an
    # externally configured token is never replaced by an empty string.
    os.environ['NGROK_AUTHTOKEN'] = NGROK_AUTHTOKEN_VALUE

In [None]:
import os
import time
import numpy as np
import faiss
import pandas as pd
from pyspark.sql import SparkSession
from pyspark.sql import functions as F

from sparknlp.base import DocumentAssembler
from sparknlp.annotator import BertSentenceEmbeddings
from pyspark.ml import Pipeline

# --- Flask and ngrok imports (from service script) ---
from flask import Flask, request, jsonify
from pyngrok import ngrok
import threading
import logging

# --- Additional imports for CLIP model and Video_embed integration ---
import torch
import pickle
from transformers import CLIPModel, CLIPProcessor

# Suppress excessive pyngrok logging
# (only WARNING and above from the "pyngrok" logger will be emitted).
pyngrok_logger = logging.getLogger("pyngrok")
pyngrok_logger.setLevel(logging.WARNING)

# --- Mount Google Drive --- (Moved here for early execution if needed by Flask later)
# All data paths configured below live under /content/drive, so the mount has to
# happen before any index or metadata file is read.
from google.colab import drive
try:
    # force_remount=True is passed so that a previous mount in this session is replaced.
    drive.mount('/content/drive', force_remount=True)
    print("Google Drive mounted successfully.")
except Exception as e:
    # A mount failure is only reported here; execution continues and the later
    # loading steps are the ones that decide whether the service can start.
    print(f"Error mounting Google Drive: {e}")

# --- CONFIGURATION ---
# Every path below points into the Google Drive mount at /content/drive.
# QA Data Paths (from build_qa_faiss.py)
QA_BASE_PATH = '/content/drive/My Drive/QA_dataset'
# Parquet file read for the QA metadata columns (qa_id, Question, Answer, Level).
EMBEDDINGS_PARQUET_PATH_QA = os.path.join(QA_BASE_PATH, "qa_combined_embeddings.parquet")
# FAISS index searched for QA results.
FAISS_INDEX_PATH_QA = os.path.join(QA_BASE_PATH, "qa_combined_embeddings.index")

# YouTube transcripts and titles embeddings Data Paths (from build_youtube_faiss.py)
YT_BASE_PATH = '/content/drive/My Drive/Youtube_100M_dataset_v3'
# Parquet file read for the YouTube metadata columns (video_id, title, transcript).
EMBEDDINGS_PARQUET_PATH_YT = os.path.join(YT_BASE_PATH, "combined_video_embeddings.parquet")
# FAISS index searched for YouTube results.
FAISS_INDEX_PATH_YT = os.path.join(YT_BASE_PATH, "combined_video_embeddings.index")

# CLIP Model and Video_embed Data Paths
# NOTE(review): these use 'MyDrive' while the QA/YT paths above use 'My Drive' —
# confirm both spellings resolve in the target Colab runtime.
CLIP_PROJECT_DIR = '/content/drive/MyDrive/video_embeddings_project' # Base directory from Video_embed.ipynb
CLIP_FAISS_INDEX_PATH = os.path.join(CLIP_PROJECT_DIR, "video_embeddings.index")
# Pickled metadata; the loader expects a list of dictionaries.
CLIP_METADATA_PATH = os.path.join(CLIP_PROJECT_DIR, "metadata.pkl")

# New Educational Video Paths
# Second CLIP index/metadata pair; merged with the original pair at load time.
CLIP_EDU_PROJECT_DIR = '/content/drive/MyDrive/video_embeddings_project_edu'
CLIP_EDU_FAISS_INDEX_PATH = os.path.join(CLIP_EDU_PROJECT_DIR, "video_embeddings.index")
CLIP_EDU_METADATA_PATH = os.path.join(CLIP_EDU_PROJECT_DIR, "metadata.pkl")

# Hugging Face model identifier passed to CLIPModel/CLIPProcessor.from_pretrained.
CLIP_MODEL_NAME = "openai/clip-vit-base-patch16"

# --- Global variables for service (expanded from service script) ---
# All of these start as None and are filled in by the initialization code below;
# the startup checks refuse to start the service if the non-CLIP ones are still None.
spark = None
query_embedding_pipeline_model = None
faiss_index_qa = None
metadata_pd_df_qa = None
faiss_index_yt = None
metadata_pd_df_yt = None
# Set once the ngrok tunnel is connected; shown on the home page.
ngrok_public_url = None

# Global variables for CLIP integration
# These stay None when CLIP loading fails, which disables reranking.
clip_model = None
clip_processor = None
faiss_index_clip = None
metadata_clip = None # This will be a list of dicts
clip_device = None

# --- Helper functions for merging FAISS indices and metadata ---
def merge_and_deduplicate_faiss_indices(existing_index, new_index, similarity_threshold=0.99, batch_size=100):
    """
    Merge two FAISS indices into a fresh ``IndexFlatIP``, skipping near-duplicate vectors.

    All vectors of ``existing_index`` are copied first. Vectors of ``new_index`` are then
    processed in batches: a vector is appended only when its best inner-product score
    against the merged index is below ``similarity_threshold``.

    Args:
        existing_index: FAISS index whose vectors are always kept, or None.
        new_index: FAISS index whose vectors are deduplicated against the merged
            index, or None.
        similarity_threshold: Inner-product score at or above which a new vector is
            treated as a duplicate. Presumably meaningful only for L2-normalized
            vectors (inner product == cosine similarity) — confirm with the code
            that built the indices.
        batch_size: Number of new vectors searched and added per batch.

    Returns:
        ``(merged_index, valid_indices)`` where ``valid_indices`` holds the positions
        in ``new_index`` that were appended, in the order they were appended. When
        one input is None the other index is returned unchanged together with
        ``range(index.ntotal)`` (original behavior, kept for compatibility). When
        both are None, ``(None, range(0))`` is returned.

    Raises:
        AssertionError: If the two indices have different dimensionality.
        ValueError: If ``batch_size`` is not positive.

    Notes:
        * Vectors are read back with ``reconstruct_n``; FAISS index objects have no
          ``get_vector`` method. Index types that do not support reconstruction
          will raise from FAISS.
        * Duplicates are detected against vectors added in *previous* batches only;
          two near-identical vectors inside the same batch are both kept.
    """
    if existing_index is None and new_index is None:
        return None, range(0)

    if existing_index is None:
        return new_index, range(new_index.ntotal)

    if new_index is None:
        return existing_index, range(existing_index.ntotal)

    if batch_size <= 0:
        raise ValueError("batch_size must be a positive integer")

    print("Starting to merge and deduplicate FAISS indices...")

    # Ensure they have the same dimensions
    assert existing_index.d == new_index.d, "The two indices have mismatched vector dimensions"

    # Create merged index
    merged_index = faiss.IndexFlatIP(existing_index.d)

    # Add all vectors from the existing index in one bulk reconstruction
    if existing_index.ntotal > 0:
        existing_vectors = np.ascontiguousarray(
            existing_index.reconstruct_n(0, existing_index.ntotal), dtype=np.float32
        )
        merged_index.add(existing_vectors)
    print(f"Added {existing_index.ntotal} vectors from existing index")

    # Positions (in new_index) of the vectors that end up in the merged index
    valid_indices = []

    total_new_vectors = new_index.ntotal
    total_batches = (total_new_vectors + batch_size - 1) // batch_size  # Ceiling division
    total_added = 0

    for batch_idx in range(total_batches):
        start_idx = batch_idx * batch_size
        end_idx = min(start_idx + batch_size, total_new_vectors)

        # Extract vectors for this batch
        batch_vectors = np.ascontiguousarray(
            new_index.reconstruct_n(start_idx, end_idx - start_idx), dtype=np.float32
        )

        # Best inner-product score of every batch vector against what is already merged
        distances, _ = merged_index.search(batch_vectors, 1)

        # Keep only vectors whose best match is below the duplicate threshold
        keep_mask = distances[:, 0] < similarity_threshold
        kept_offsets = np.flatnonzero(keep_mask)
        if kept_offsets.size:
            merged_index.add(np.ascontiguousarray(batch_vectors[kept_offsets]))
            valid_indices.extend((start_idx + kept_offsets).tolist())
            total_added += int(kept_offsets.size)

        # Print progress
        if (batch_idx + 1) % 10 == 0 or batch_idx == total_batches - 1:
            progress = (batch_idx + 1) / total_batches * 100
            print(f"Processing progress: {progress:.1f}% ({batch_idx + 1}/{total_batches} batches), added {total_added} non-duplicate vectors")

    print(f"Added {total_added} non-duplicate vectors from new index")
    print(f"Merged index now contains {merged_index.ntotal} vectors total")

    return merged_index, valid_indices

def merge_and_deduplicate_metadata(existing_metadata, new_metadata, valid_indices, drop_duplicate_video_ids=False):
    """
    Merge metadata lists so they stay position-aligned with a merged FAISS index.

    The merged index holds every existing vector followed by the new vectors at
    ``valid_indices``. Metadata is looked up by FAISS position during reranking
    (``clip_meta[embedding_idx]``), so the merged list must hold exactly one entry
    per merged vector, in the same order. Dropping entries by ``video_id`` (the
    previous default) shifted every later entry onto the wrong vector.

    Args:
        existing_metadata: List of dicts for the existing index, or None.
        new_metadata: List of dicts for the new index, or None.
        valid_indices: Positions in ``new_metadata`` whose vectors were added to
            the merged index.
        drop_duplicate_video_ids: When True, restores the legacy behavior of
            discarding new entries whose ``video_id`` was already seen. This
            breaks alignment with the merged index and should only be used when
            the result is not indexed by FAISS position.

    Returns:
        The merged list of metadata dicts. If ``new_metadata`` is None the
        ``existing_metadata`` object itself is returned (or ``[]`` when both are
        None).
    """
    if new_metadata is None:
        return [] if existing_metadata is None else existing_metadata

    # Only new metadata corresponding to vectors that were actually merged
    filtered_new_metadata = [new_metadata[i] for i in valid_indices]

    if existing_metadata is None:
        return filtered_new_metadata

    existing_video_ids = set(item.get('video_id', '') for item in existing_metadata if item.get('video_id'))

    if drop_duplicate_video_ids:
        # Legacy path: one entry per previously unseen video_id.
        kept_new_metadata = []
        duplicate_count = 0
        for item in filtered_new_metadata:
            video_id = item.get('video_id', '')
            if video_id and video_id in existing_video_ids:
                duplicate_count += 1
                continue
            kept_new_metadata.append(item)
            if video_id:
                existing_video_ids.add(video_id)
        print(f"Metadata deduplication: Ignored {duplicate_count} entries with duplicate video IDs")
    else:
        # Alignment-preserving path: keep every entry, only report the overlap.
        kept_new_metadata = filtered_new_metadata
        overlap_count = sum(
            1 for item in filtered_new_metadata
            if item.get('video_id') and item.get('video_id') in existing_video_ids
        )
        print(f"Metadata merge: kept {overlap_count} new entries whose video IDs already exist "
              f"(entries are kept to stay aligned with the merged FAISS index)")

    merged_metadata = existing_metadata + kept_new_metadata
    print(f"Merged metadata now contains {len(merged_metadata)} entries total")

    return merged_metadata

# --- INITIALIZATION (Functions from build_qa_faiss_version2.py, called at startup) ---
# Create (or reuse, via getOrCreate) the Spark session used by the query
# embedding pipeline and by the Parquet metadata loaders. The builder requests
# the Spark NLP 5.3.0 jar, a local[*] master, and Arrow enabled for pandas
# conversion.
# NOTE(review): the jar version here should match the spark-nlp Python package
# installed in the first cell.
print("Initializing Spark session...")
try:
    spark = SparkSession.builder \
        .appName("Combined_FAISS_Service_Spark") \
        .config("spark.jars.packages", "com.johnsnowlabs.nlp:spark-nlp_2.12:5.3.0") \
        .config("spark.driver.memory", "16G") \
        .config("spark.driver.maxResultSize", "8g") \
        .config("spark.sql.execution.arrow.pyspark.enabled", "true") \
        .master("local[*]") \
        .getOrCreate()
    print("Spark session started.")
except Exception as e:
    # Nothing below can work without Spark, so the run is aborted here.
    print(f"Error initializing Spark session: {e}")
    exit() # Critical failure

def build_query_pipeline_global(): # Renamed to avoid conflict if other local funcs exist
    """Build and fit the Spark NLP pipeline that turns a ``text`` column into sentence embeddings.

    The pipeline has two stages: a ``DocumentAssembler`` (text -> document) and the
    pretrained ``sent_small_bert_L2_128`` sentence-embedding model
    (document -> embedding, case-insensitive). It is fitted on a one-row dummy
    DataFrame created from the module-level ``spark`` session.

    Returns:
        The fitted pipeline model.
    """
    assembler_stage = DocumentAssembler()
    assembler_stage = assembler_stage.setInputCol("text").setOutputCol("document")

    embedding_stage = BertSentenceEmbeddings.pretrained("sent_small_bert_L2_128")
    embedding_stage = (
        embedding_stage
        .setInputCols(["document"])
        .setOutputCol("embedding")
        .setCaseSensitive(False)
    )

    pipeline = Pipeline(stages=[assembler_stage, embedding_stage])

    # A single placeholder row is used as the DataFrame to fit on.
    placeholder_rows = [("dummy query text",)]
    placeholder_df = spark.createDataFrame(placeholder_rows, ["text"])
    return pipeline.fit(placeholder_df)

# Build the query embedding pipeline once at startup and keep it in the
# module-level global that the /search endpoint reads.
print("\nPreparing query embedding generator...")
try:
    query_embedding_pipeline_model = build_query_pipeline_global()
    print("Query embedding generator ready.")
except Exception as e:
    # Without the pipeline no query can be embedded: stop Spark and abort.
    print(f"Error: Could not prepare query embedding generator: {e}")
    if spark: spark.stop()
    exit() # Critical failure

def load_metadata_global():
    """Load the QA and YouTube metadata tables from Parquet into pandas globals.

    Reads ``EMBEDDINGS_PARQUET_PATH_QA`` (columns qa_id, Question, Answer, Level)
    into ``metadata_pd_df_qa`` and ``EMBEDDINGS_PARQUET_PATH_YT`` (columns
    video_id, title, transcript) into ``metadata_pd_df_yt`` using the module-level
    ``spark`` session.

    Arrow-based conversion is switched off for the duration of the ``toPandas``
    calls and is always restored to the value it had on entry, on success and on
    failure alike (previously a failure reset it to a hard-coded "true").

    Returns:
        True when both tables were loaded, False otherwise. If the QA table loads
        and the YouTube table fails, ``metadata_pd_df_qa`` stays populated.
    """
    global metadata_pd_df_qa, metadata_pd_df_yt # Explicitly declare we are modifying globals
    arrow_conf_key = "spark.sql.execution.arrow.pyspark.enabled"

    print(f"\nLoading QA metadata from: {EMBEDDINGS_PARQUET_PATH_QA}")
    print(f"Loading YT metadata from: {EMBEDDINGS_PARQUET_PATH_YT}")

    # Remember the caller's setting once; it is the value restored afterwards.
    original_arrow_status = spark.conf.get(arrow_conf_key)
    try:
        spark.conf.set(arrow_conf_key, "false")
        print("Temporarily disabled Arrow for metadata toPandas.")

        loaded_metadata_spark_df_qa = spark.read.parquet(EMBEDDINGS_PARQUET_PATH_QA)
        metadata_pd_df_qa = loaded_metadata_spark_df_qa.select("qa_id", "Question", "Answer", "Level").toPandas()
        print(f"QA metadata loaded. Records: {len(metadata_pd_df_qa)}")

        loaded_metadata_spark_df_yt = spark.read.parquet(EMBEDDINGS_PARQUET_PATH_YT)
        metadata_pd_df_yt = loaded_metadata_spark_df_yt.select("video_id", "title", "transcript").toPandas()
        print(f"YT metadata loaded. Records: {len(metadata_pd_df_yt)}")
    except Exception as e:
        print(f"Error: Could not load or process Parquet metadata. Error: {e}")
        return False # Indicate failure
    finally:
        # Runs on both the success and the failure path.
        try:
            spark.conf.set(arrow_conf_key, original_arrow_status)
            print(f"Restored Arrow to {original_arrow_status}.")
        except Exception as conf_e:
            print(f"Could not restore Arrow config after error: {conf_e}")
    return True # Indicate success

# Load both metadata tables at startup; the service cannot answer searches
# without them, so a failure stops Spark and aborts the run.
if not load_metadata_global():
    print("Critical Error: Metadata loading failed. Service cannot start.")
    if spark: spark.stop()
    exit()

def load_faiss_indices_global():
    """Read the QA and YouTube FAISS indices from disk into module globals.

    Populates ``faiss_index_qa`` from ``FAISS_INDEX_PATH_QA`` and then
    ``faiss_index_yt`` from ``FAISS_INDEX_PATH_YT``.

    Returns:
        True when both indices were read, False if reading either one raised.
    """
    global faiss_index_qa, faiss_index_yt # Explicitly declare we are modifying globals

    print(f"\nLoading QA FAISS index from: {FAISS_INDEX_PATH_QA}")
    print(f"Loading YT FAISS index from: {FAISS_INDEX_PATH_YT}")

    try:
        # QA index first, then the YouTube index.
        faiss_index_qa = faiss.read_index(FAISS_INDEX_PATH_QA)
        print(f"QA FAISS index loaded. Vectors: {faiss_index_qa.ntotal}")

        faiss_index_yt = faiss.read_index(FAISS_INDEX_PATH_YT)
        print(f"YT FAISS index loaded. Vectors: {faiss_index_yt.ntotal}")
    except Exception as load_error:
        print(f"Error: Could not load FAISS index files. Error: {load_error}")
        return False
    else:
        return True

# Load both text-embedding FAISS indices at startup; a failure stops Spark and
# aborts the run because no search can be served without them.
if not load_faiss_indices_global():
    print("Critical Error: FAISS index loading failed. Service cannot start.")
    if spark: spark.stop()
    exit()

def _load_clip_index_and_metadata(label, index_path, metadata_path):
    """Load one CLIP FAISS index and its pickled metadata, tolerating missing files.

    Args:
        label: Lower-case name used in log messages ("original", "educational").
        index_path: Path of the FAISS index file.
        metadata_path: Path of the pickled metadata file.

    Returns:
        ``(index, metadata)``; either element is None when its file does not exist.
    """
    title = label.capitalize()
    index = None
    metadata = None

    print(f"Loading {label} CLIP FAISS index from: {index_path}...")
    if os.path.exists(index_path):
        index = faiss.read_index(index_path)
        print(f"{title} CLIP FAISS index loaded. Vector count: {index.ntotal}")
    else:
        print(f"Warning: {title} CLIP FAISS index not found at {index_path}")

    print(f"Loading {label} CLIP metadata from: {metadata_path}...")
    if os.path.exists(metadata_path):
        # NOTE(security): pickle.load can execute arbitrary code. Only load
        # metadata files that were produced by a trusted notebook/run.
        with open(metadata_path, 'rb') as f:
            metadata = pickle.load(f)
        print(f"{title} CLIP metadata loaded. Entry count: {len(metadata)}")
        # Basic validation of metadata structure
        if not isinstance(metadata, list) or (len(metadata) > 0 and not isinstance(metadata[0], dict)):
            print(f"Warning: {title} CLIP metadata is not a list of dictionaries")
    else:
        print(f"Warning: {title} CLIP metadata not found at {metadata_path}")

    return index, metadata


def load_clip_resources():
    """Loads the CLIP model, processor, FAISS index, and metadata from Video_embed.ipynb.

    Populates the globals ``clip_device``, ``clip_processor``, ``clip_model``,
    ``faiss_index_clip`` and ``metadata_clip``. Two index/metadata pairs are
    attempted (original and educational); when both indices exist they are merged
    with deduplication, otherwise whichever one exists is used directly.

    Returns:
        True when a CLIP index together with its metadata is ready, False
        otherwise. Any exception raised while loading is reported and also results
        in False.
    """
    global clip_model, clip_processor, faiss_index_clip, metadata_clip, clip_device

    print("\n--- Loading CLIP Model and Video_embed Resources ---")
    try:
        clip_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Using device for CLIP model: {clip_device}")

        print(f"Loading CLIP processor: {CLIP_MODEL_NAME}...")
        clip_processor = CLIPProcessor.from_pretrained(CLIP_MODEL_NAME)
        print(f"Loading CLIP model: {CLIP_MODEL_NAME}...")
        clip_model = CLIPModel.from_pretrained(CLIP_MODEL_NAME).to(clip_device)
        clip_model.eval() # Set to evaluation mode
        print("CLIP model and processor loaded successfully.")

        # 1. Load original FAISS index and metadata
        original_faiss_index, original_metadata = _load_clip_index_and_metadata(
            "original", CLIP_FAISS_INDEX_PATH, CLIP_METADATA_PATH
        )

        # 2. Load educational video FAISS index and metadata
        edu_faiss_index, edu_metadata = _load_clip_index_and_metadata(
            "educational", CLIP_EDU_FAISS_INDEX_PATH, CLIP_EDU_METADATA_PATH
        )

        # 3. Merge indices and metadata (with deduplication)
        if original_faiss_index is None and edu_faiss_index is None:
            print("Error: All CLIP FAISS indices failed to load")
            return False

        # If only one index is available, use it directly
        if original_faiss_index is None:
            selected_index = edu_faiss_index
            selected_metadata = edu_metadata
            print("Using only educational video index and metadata")
        elif edu_faiss_index is None:
            selected_index = original_faiss_index
            selected_metadata = original_metadata
            print("Using only original index and metadata")
        else:
            # Merging an index whose metadata is missing would produce a metadata
            # list that no longer lines up with the merged vectors.
            if original_metadata is None or edu_metadata is None:
                print("Error: Cannot merge CLIP indices because metadata is missing for at least one of them")
                return False

            # Merge and deduplicate both indices and metadata
            print("Merging both indices and metadata with deduplication...")
            selected_index, valid_indices = merge_and_deduplicate_faiss_indices(original_faiss_index, edu_faiss_index)
            selected_metadata = merge_and_deduplicate_metadata(original_metadata, edu_metadata, valid_indices)
            print("Completed merging and deduplication of indices and metadata!")

        if selected_metadata is None:
            print("Error: CLIP metadata is missing for the loaded FAISS index")
            return False

        faiss_index_clip = selected_index
        metadata_clip = selected_metadata

        print(f"Final FAISS index contains {faiss_index_clip.ntotal} vectors")
        print(f"Final metadata contains {len(metadata_clip)} entries")
        if len(metadata_clip) != faiss_index_clip.ntotal:
            # Reranking looks metadata up by FAISS position, so a size mismatch
            # means some hits will map to the wrong (or to no) video.
            print("Warning: CLIP metadata entry count does not match FAISS vector count; "
                  "reranking may attribute scores to the wrong videos")

        print("CLIP resources loaded successfully.")
        return True
    except Exception as e:
        print(f"Error loading CLIP resources: {e}")
        # import traceback
        # traceback.print_exc()
        return False

# Attempt to load CLIP resources for reranking
if not load_clip_resources():
    print("Warning: CLIP resources for reranking could not be loaded. Reranking will be disabled.")
    # Depending on requirements, you might choose to exit() if CLIP is critical
    # For now, we'll allow the service to run without CLIP reranking capabilities.
    # Only the index and metadata are reset here; the /search endpoint checks them
    # (together with the model and processor) before attempting to rerank.
    faiss_index_clip = None # Ensure it's None if loading failed
    metadata_clip = None

# Critical checks after loading (ensure all necessary components are ready)
# Each entry is a boolean; all of them must be True for the service to start.
components_ready = [
    spark is not None,
    query_embedding_pipeline_model is not None,
    faiss_index_qa is not None,
    metadata_pd_df_qa is not None, # Checks if it's assigned (not None)
    faiss_index_yt is not None,
    metadata_pd_df_yt is not None,  # Checks if it's assigned (not None)
    # CLIP components are optional for base functionality, but checked here for completeness if loaded
    # If load_clip_resources() returned False, these will be None and reranking won't occur.
    # (No CLIP entries are actually part of this list.)
]

if not all(components_ready):
    print("Critical Error: One or more essential components failed to initialize. Service cannot start.")
    # Stop Spark only if a session is still active, then abort the run.
    if spark and spark.getActiveSession(): spark.stop()
    exit()

# --- EMBEDDING FUNCTION (from build_qa_faiss_version2.py) ---
def embed_query(query_text, pipeline_model, spark_session):
    """Embed a single query string with the fitted Spark NLP pipeline.

    Args:
        query_text: The text to embed.
        pipeline_model: Fitted pipeline producing an ``embedding`` column.
        spark_session: Spark session used to build the one-row input DataFrame.

    Returns:
        A float32 NumPy array built from ``[vector]``, where ``vector`` is the
        first embedding produced for the query.

    Raises:
        RuntimeError: If ``pipeline_model`` is None.
    """
    # Startup checks should prevent this, but fail loudly rather than with an AttributeError.
    if pipeline_model is None:
        raise RuntimeError("Query embedding pipeline model is not initialized.")

    rows = [("query_id_temp", query_text)]
    single_query_df = spark_session.createDataFrame(rows, ["id", "text"])
    embedded_df = pipeline_model.transform(single_query_df)

    # Take the first annotation's embedding and expose it under the name "vector".
    first_embedding = F.col("embedding.embeddings")[0].alias("vector")
    collected_rows = embedded_df.select(first_embedding).collect()
    vector = collected_rows[0]["vector"]

    return np.array([vector]).astype('float32')

def def_query_clip(query_text, model, processor, device):
    """Generates a normalized CLIP text embedding for a query.

    Args:
        query_text: The text to embed.
        model: CLIP model exposing ``get_text_features``.
        processor: CLIP processor used to tokenize the text.
        device: Device the tokenized inputs are moved to.

    Returns:
        A float32 NumPy array with the embedding divided by its norm along the
        last dimension, or None when embedding generation raised.

    Raises:
        RuntimeError: If ``model`` or ``processor`` is None.
    """
    if model is None or processor is None:
        # Either caught at startup when CLIP is critical, or handled by callers
        # that treat reranking as optional.
        raise RuntimeError("CLIP model or processor is not initialized.")

    try:
        encoded_text = processor(text=[query_text], return_tensors="pt", padding=True)
        encoded_text = encoded_text.to(device)

        # Inference only: no gradients are needed.
        with torch.no_grad():
            raw_features = model.get_text_features(**encoded_text)

        feature_norm = raw_features.norm(dim=-1, keepdim=True)
        unit_features = raw_features / feature_norm
        return unit_features.cpu().numpy().astype('float32')
    except Exception as e:
        print(f"Error generating CLIP embedding for query '{query_text}': {e}")
        # import traceback
        # traceback.print_exc()
        return None # Or raise an error

# --- SEARCH FUNCTIONS (from build_qa_faiss_version2.py, minor changes for JSON) ---
# Note: These now return lists of dicts for easier jsonify, instead of DataFrames.
# Also, added explicit type conversion for JSON serializability where pandas might use numpy types.

def search_qa_faiss_service(index, query_embedding_np, k, metadata_df_ref):
    """Search the QA index and return JSON-friendly result dictionaries.

    Args:
        index: Object exposing ``search(query, k) -> (distances, indices)``.
        query_embedding_np: Query embedding passed straight to ``index.search``.
        k: Number of neighbours requested.
        metadata_df_ref: Table looked up positionally via ``.iloc``.

    Returns:
        A list with one dict per returned neighbour, in rank order. Each dict has
        ``rank``, ``distance`` and ``similarity`` (the latter two formatted to four
        decimals) plus either the QA fields (``qa_id``, ``Question``, ``Answer``,
        ``Level``) or ``internal_faiss_id``/``Error`` when the position is outside
        the metadata table. Returns ``[]`` when ``index`` or ``metadata_df_ref`` is None.
    """
    if index is None or metadata_df_ref is None:
        return [] # Return empty list if components missing

    def _to_json_scalar(value):
        # NumPy integers become plain ints; everything else becomes a string,
        # with missing values rendered as "N/A".
        if isinstance(value, np.integer):
            return int(value)
        return str(value if pd.notna(value) else "N/A")

    distances, indices = index.search(query_embedding_np, k)

    hits = []
    for rank, (raw_position, dist) in enumerate(zip(indices[0], distances[0]), start=1):
        position = int(raw_position)

        if dist >= 0:
            similarity = 1 / (1 + dist)
        else:
            similarity = float('inf')

        hit = {"rank": rank, "distance": f"{dist:.4f}", "similarity": f"{similarity:.4f}"}

        if position < 0 or position >= len(metadata_df_ref):
            hit["internal_faiss_id"] = position
            hit["Error"] = "Metadata not found for this index"
        else:
            row = metadata_df_ref.iloc[position]
            hit["qa_id"] = _to_json_scalar(row.get("qa_id"))
            hit["Question"] = str(row.get("Question", "N/A"))
            hit["Answer"] = str(row.get("Answer", "N/A"))
            hit["Level"] = _to_json_scalar(row.get("Level"))

        hits.append(hit)

    return hits

def search_yt_faiss_service(index, query_embedding_np, k, metadata_df_ref):
    """Search the YouTube index and return JSON-friendly result dictionaries.

    Args:
        index: Object exposing ``search(query, k) -> (distances, indices)``.
        query_embedding_np: Query embedding passed straight to ``index.search``.
        k: Number of neighbours requested.
        metadata_df_ref: Table looked up positionally via ``.iloc``.

    Returns:
        A list with one dict per returned neighbour, in rank order. Each dict has
        ``rank``, ``distance`` and ``similarity`` plus either the video fields
        (``video_id``, ``title``, ``transcript``) or ``internal_faiss_id``/``Error``
        when the position is outside the metadata table. Returns ``[]`` when
        ``index`` or ``metadata_df_ref`` is None.
    """
    if index is None or metadata_df_ref is None:
        return []

    distances, indices = index.search(query_embedding_np, k)

    hits = []
    for rank, (raw_position, dist) in enumerate(zip(indices[0], distances[0]), start=1):
        position = int(raw_position)

        if dist >= 0:
            similarity = 1 / (1 + dist)
        else:
            similarity = float('inf')

        hit = {"rank": rank, "distance": f"{dist:.4f}", "similarity": f"{similarity:.4f}"}

        if position < 0 or position >= len(metadata_df_ref):
            hit["internal_faiss_id"] = position
            hit["Error"] = "Metadata not found for this index"
        else:
            row = metadata_df_ref.iloc[position]
            # Fall back to placeholder text when a column is absent from the row.
            hit["video_id"] = str(row.get("video_id", f"InternalID-{position}"))
            hit["title"] = str(row.get("title", "(No title information)"))
            hit["transcript"] = str(row.get("transcript", "(No transcript information)"))

        hits.append(hit)

    return hits

def rerank_youtube_results_with_clip(initial_yt_results, query_text, k_for_rerank,
                                     clip_m, clip_p, clip_idx, clip_meta, device):
    """Reranks YouTube search results using CLIP embeddings and FAISS index from Video_embed.ipynb.

    The query is embedded with CLIP and searched against ``clip_idx``. Every hit is
    mapped to a ``video_id`` through ``clip_meta`` (indexed by FAISS position) and
    the highest score per video is kept. Initial results whose ``video_id`` received
    a score are copied, given a ``clip_score`` field and sorted by it (descending);
    the remaining results follow in their original order.

    Args:
        initial_yt_results: List of result dicts from the text-based search.
        query_text: The user's query.
        k_for_rerank: Requested result count; scales how many CLIP hits are fetched.
        clip_m, clip_p: CLIP model and processor.
        clip_idx: FAISS index of CLIP embeddings.
        clip_meta: List of metadata dicts aligned with ``clip_idx`` positions.
        device: Device used for the CLIP forward pass.

    Returns:
        The reranked list, or ``initial_yt_results`` itself when CLIP resources are
        missing, the query could not be embedded, or any error occurred.
    """
    if not (clip_idx and clip_meta and clip_m and clip_p):
        print("CLIP resources not available, skipping reranking.")
        return initial_yt_results # Return original results if CLIP components are missing

    print(f"Reranking {len(initial_yt_results)} YouTube results with CLIP for query: '{query_text}'.")

    try:
        query_vector = def_query_clip(query_text, clip_m, clip_p, device)
        if query_vector is None:
            print("Could not generate CLIP query embedding. Skipping reranking.")
            return initial_yt_results

        # 1. Search the CLIP index. The depth is a heuristic: at least 500 hits
        #    (or 10x the requested count), capped at the size of the index.
        search_depth = min(clip_idx.ntotal, max(500, k_for_rerank * 10))

        print(f"Searching CLIP FAISS index (top {search_depth} frames)...")
        hit_scores, hit_positions = clip_idx.search(query_vector, search_depth)

        # 2. Best CLIP score per video_id among the returned hits.
        best_score_by_video = {}
        if hit_positions.size > 0:
            for column in range(hit_positions.shape[1]):
                embedding_idx = hit_positions[0][column]
                score = hit_scores[0][column] # For IndexFlatIP, this is the dot product (similarity)

                if not (0 <= embedding_idx < len(clip_meta)):
                    print(f"Warning: embedding_idx {embedding_idx} out of bounds for clip_meta (len {len(clip_meta)}).")
                    continue

                video_id = clip_meta[embedding_idx].get('video_id')
                if not video_id:
                    continue

                # Keep the higher score when the video was already seen.
                previous_best = best_score_by_video.get(video_id)
                if previous_best is None or score > previous_best:
                    best_score_by_video[video_id] = float(score)

        print(f"Found CLIP scores for {len(best_score_by_video)} unique video IDs.")

        # 3. Split the initial results by whether a CLIP score was found.
        scored_results = []
        unscored_results = []
        for result in initial_yt_results:
            result_video_id = result.get('video_id')
            if result_video_id not in best_score_by_video:
                unscored_results.append(result)
                continue
            scored_copy = result.copy() # Avoid modifying original list items directly
            scored_copy['clip_score'] = best_score_by_video[result_video_id]
            scored_results.append(scored_copy)

        # 4. Highest CLIP score first.
        if scored_results:
            scored_results.sort(key=lambda entry: entry['clip_score'], reverse=True)
            # Optionally remove the 'clip_score' if not needed in the final output
            print(f"Reranked {len(scored_results)} videos using CLIP scores.")

        # 5. Scored results first, then everything CLIP had no opinion on.
        combined_results = scored_results + unscored_results

        print("CLIP reranking process completed.")
        return combined_results

    except Exception as e:
        print(f"Error during CLIP reranking for query '{query_text}': {e}")
        # import traceback
        # traceback.print_exc()
        return initial_yt_results # Fallback to original results in case of error

# --- Flask App Setup (from service script) ---
# Single Flask application object; the route decorators below register on it and
# run_flask_app() serves it.
app = Flask(__name__)

@app.route("/")
def home():
    """Landing page: usage instructions once the ngrok URL is known, otherwise a short "starting" notice."""
    # The module-level URL is only read here, so no global declaration is needed.
    if not ngrok_public_url:
        return "<h1>Combined QA & YouTube FAISS Search Service</h1><p>Service is starting, ngrok URL not yet available...</p>"

    return f"""
        <h1>Combined QA & YouTube FAISS Search Service</h1>
        <p>Service is running and accessible via ngrok.</p>
        <p>Use the /search endpoint with a 'query' parameter and optionally a 'k' parameter.</p>
        <p><b>Example:</b> <a href="{ngrok_public_url}/search?query=Explain+photosynthesis&k=3" target="_blank">{ngrok_public_url}/search?query=Explain+photosynthesis&k=3</a></p>
        <p>Current ngrok URL: {ngrok_public_url}</p>
        """

@app.route("/search", methods=["GET"])
def search_endpoint():
    """Handle ``GET /search?query=...&k=...``.

    Query parameters:
        query: Required, non-blank search text.
        k: Optional number of results per source. Defaults to 5 when missing,
           non-numeric or non-positive, and is clamped to 100 because it is
           user-controlled and is passed to FAISS (and multiplied for the CLIP
           search depth).

    Returns:
        JSON ``{"qa_results": [...], "yt_results": [...]}`` on success, a 400 JSON
        error when ``query`` is missing or blank, or a 500 JSON error when the
        search fails.
    """
    default_k = 5
    max_k = 100

    query_text = request.args.get("query")
    try:
        k_results = int(request.args.get("k", default=default_k))
        if k_results <= 0: k_results = default_k
    except ValueError:
        k_results = default_k

    if k_results > max_k:
        print(f"Requested k={k_results} exceeds the maximum of {max_k}; clamping.")
        k_results = max_k

    # Whitespace-only queries carry no searchable text either.
    if not query_text or not query_text.strip():
        return jsonify({"error": "Query parameter 'query' is required."}), 400

    print(f"Received search request: Query='{query_text}', k={k_results}")
    try:
        query_vec = embed_query(query_text, query_embedding_pipeline_model, spark)
        if query_vec is None: # Should not happen if embed_query raises error, but defensive
            return jsonify({"error": f"Could not generate embedding for query: '{query_text}'"}), 500

        qa_search_results = search_qa_faiss_service(faiss_index_qa, query_vec, k_results, metadata_pd_df_qa)
        yt_search_results_initial = search_yt_faiss_service(faiss_index_yt, query_vec, k_results, metadata_pd_df_yt)

        # --- Apply CLIP-based reranking to YouTube results ---
        if faiss_index_clip and metadata_clip and clip_model and clip_processor:
            print(f"Attempting CLIP reranking for YouTube results (initial count: {len(yt_search_results_initial)}).")
            final_yt_results = rerank_youtube_results_with_clip(
                initial_yt_results=yt_search_results_initial,
                query_text=query_text,
                k_for_rerank=k_results, # Pass k_results to help determine CLIP search depth
                clip_m=clip_model,
                clip_p=clip_processor,
                clip_idx=faiss_index_clip,
                clip_meta=metadata_clip,
                device=clip_device
            )
        else:
            print("CLIP resources not available. Using original YouTube search results.")
            final_yt_results = yt_search_results_initial
        # --- End of CLIP reranking ---

        response_data = {
            "qa_results": qa_search_results,
            "yt_results": final_yt_results # Use the potentially reranked results
        }
        print(f"Search for '{query_text}' (k={k_results}) completed. Found {len(qa_search_results)} QA results, {len(final_yt_results)} YT results.")
        return jsonify(response_data)

    except RuntimeError as e: # Catch error from embed_query if model not ready
        print(f"RuntimeError during search: {e}")
        return jsonify({"error": str(e)}), 500
    except Exception as e:
        print(f"Unexpected error during search for '{query_text}': {e}")
        # Log the full traceback for debugging if possible in a real server
        # import traceback; traceback.print_exc()
        return jsonify({"error": "An unexpected error occurred during search."}), 500

def run_flask_app():
    """Serve the module-level Flask ``app`` on 0.0.0.0:5000.

    The 'werkzeug' logger is raised to ERROR before the server starts, which
    keeps lower-severity server log lines out of the notebook output. Debug
    mode and the reloader are both explicitly disabled. Takes no arguments so
    it can be handed directly to ``threading.Thread(target=...)``.
    """
    logging.getLogger('werkzeug').setLevel(logging.ERROR)

    # Keyword options forwarded verbatim to Flask's development server.
    server_options = {
        "host": "0.0.0.0",
        "port": 5000,
        "debug": False,
        "use_reloader": False,
    }
    app.run(**server_options)

# --- Main Execution Logic (from service script) ---
if __name__ == "__main__":
    # Entry point: open an ngrok tunnel to the local Flask port, serve the app
    # from a daemon thread, keep the cell alive until that thread exits or the
    # kernel is interrupted, then tear down ngrok and the Spark session.
    # Initializations are done above. Critical checks ensure components are ready.
    import traceback  # local import: only used on the unexpected-error path

    # Port the tunnel forwards to; must match the port run_flask_app() binds.
    flask_port = 5000
    # Reserved ngrok domains are account-specific, so allow an override through
    # the NGROK_HOSTNAME environment variable. The previously hard-coded value
    # remains the default, so existing runs behave as before.
    ngrok_hostname = os.environ.get("NGROK_HOSTNAME") or "novel-osprey-uncommon.ngrok-free.app"

    print("\n--- Setting up ngrok tunnel ---")
    ngrok_auth_token = os.environ.get("NGROK_AUTHTOKEN")
    if ngrok_auth_token:
        ngrok.set_auth_token(ngrok_auth_token)
        print("NGROK Authtoken set.")
    else:
        # An empty string (the placeholder value) is treated the same as unset.
        print("NGROK_AUTHTOKEN not found. Using ngrok without token (might have limitations).")

    try:
        public_url_obj = ngrok.connect(flask_port, hostname=ngrok_hostname)
        ngrok_public_url = public_url_obj.public_url
        print(f' * ngrok tunnel "{ngrok_public_url}" -> "http://127.0.0.1:{flask_port}"')
        print(f" * Access the service at: {ngrok_public_url}")
        print(f" * Example search: {ngrok_public_url}/search?query=your_query_here&k=3")

        print("\n--- Starting Flask App in a background thread ---")
        # Daemon thread so an interrupted kernel is not kept alive by the server.
        flask_thread = threading.Thread(target=run_flask_app, daemon=True)
        flask_thread.start()
        print("Flask app is running. Colab cell will remain active.")
        print("To stop: interrupt/stop the Colab kernel.")

        # Wake up once a second so a KeyboardInterrupt is handled promptly;
        # join() returns early as soon as the server thread ends.
        while flask_thread.is_alive():
            flask_thread.join(timeout=1)

        # Reaching this point without an interrupt means the server thread
        # ended on its own (e.g. it failed to bind the port); say so instead
        # of shutting down silently.
        print("Flask thread exited.")

    except KeyboardInterrupt:
        print("\nShutdown signal received.")
    except Exception as e:
        print(f"Error during ngrok/Flask main loop: {e}")
        # Keep the full stack trace; the one-line message alone rarely
        # identifies where the failure happened.
        traceback.print_exc()
    finally:
        print("\n--- Shutting down ---")
        # `ngrok` is an imported module and therefore always truthy, so the
        # cleanup runs unconditionally; failures are reported, not raised, so
        # the Spark shutdown below still happens.
        print("Closing ngrok tunnels...")
        try:
            for tunnel in ngrok.get_tunnels():
                ngrok.disconnect(tunnel.public_url)
            ngrok.kill()
            print("ngrok tunnels closed.")
        except Exception as ng_e:
            print(f"Error closing ngrok: {ng_e}")
        if spark and spark.getActiveSession():
            print("Closing Spark session...")
            spark.stop()
            print("Spark session closed.")
        print("Script execution finished.")

# Removed: print(search("how does single cell RNA seq work?")) - now handled by API
# Removed: import atexit - spark.stop() is in finally block

Mounted at /content/drive
Google Drive mounted successfully.
Initializing Spark session...
Spark session started.

Preparing query embedding generator...
sent_small_bert_L2_128 download started this may take some time.
Approximate size to download 16.1 MB
[OK!]
Query embedding generator ready.

Loading QA metadata from: /content/drive/My Drive/QA_dataset/qa_combined_embeddings.parquet
Loading YT metadata from: /content/drive/My Drive/Youtube_100M_dataset_v3/combined_video_embeddings.parquet
Temporarily disabled Arrow for QA metadata toPandas.
QA metadata loaded. Records: 680495
Restored Arrow to true for QA.
Temporarily disabled Arrow for YT metadata toPandas.
YT metadata loaded. Records: 22222
Restored Arrow to true for YT.

Loading QA FAISS index from: /content/drive/My Drive/QA_dataset/qa_combined_embeddings.index
Loading YT FAISS index from: /content/drive/My Drive/Youtube_100M_dataset_v3/combined_video_embeddings.index


Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.


QA FAISS index loaded. Vectors: 680495
YT FAISS index loaded. Vectors: 22222
\n--- Loading CLIP Model and Video_embed Resources ---
Using device for CLIP model: cpu
Loading CLIP processor: openai/clip-vit-base-patch16...


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


Loading CLIP model: openai/clip-vit-base-patch16...




CLIP model and processor loaded successfully.
Loading CLIP FAISS index from: /content/drive/MyDrive/video_embeddings_project/video_embeddings.index...
CLIP FAISS index loaded. Vectors: 21300
Loading CLIP metadata from: /content/drive/MyDrive/video_embeddings_project/metadata.pkl...
CLIP metadata loaded. Entries: 21610
CLIP resources loaded successfully.

--- Setting up ngrok tunnel ---
NGROK Authtoken set.
 * ngrok tunnel "https://cs6513edu.ngrok.app" -> "http://127.0.0.1:5000"
 * Access the service at: https://cs6513edu.ngrok.app
 * Example search: https://cs6513edu.ngrok.app/search?query=your_query_here&k=3

--- Starting Flask App in a background thread ---
Flask app is running. Colab cell will remain active.
To stop: interrupt/stop the Colab kernel.
 * Serving Flask app '__main__'
 * Debug mode: off
Received search request: Query='how to bake cake', k=3
Attempting CLIP reranking for YouTube results (initial count: 3).
Reranking 3 YouTube results with CLIP for query: 'how to bake cak