In [2]:
!pip install -q transformers torch torchvision opencv-python-headless faiss-cpu yt-dlp pyspark psutil gputil findspark

!yt-dlp -U

[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/173.3 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m173.3/173.3 kB[0m [31m5.0 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m363.4/363.4 MB[0m [31m3.7 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m13.8/13.8 MB[0m [31m64.7 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m24.6/24.6 MB[0m [31m58.6 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m883.7/883.7 kB[0m [31m43.5 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m664.8/664.8 MB[0m [31m3.0 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m211.5/211.5 MB[0m [31m6.7 MB/s[0m eta [3

In [3]:
# ============================================
# FULL PIPELINE WITH CHECKPOINT PROTECTION
# ============================================
import json
import os
import shutil
import time
import traceback
import subprocess
import tempfile
import random
from datetime import datetime
import glob
import re
import gc
import numpy as np
import pandas as pd
import torch
import faiss
import pickle
import psutil
import GPUtil
import cv2
from transformers import CLIPModel, CLIPProcessor
from PIL import Image
import io

# Define global constants
MAX_PARTITIONS = 4  # Max number of parallel Spark partitions
BATCH_SIZE = 100  # Number of videos to process in each batch
MAX_FRAMES_PER_VIDEO = 90  # Max frames to extract per video
TEMP_DIR = "/content/temp"  # Temporary directory for downloads
LOCAL_INDEX_PATH = "/content/video_embeddings.index"  # Local index path
DRIVE_PROJECT_DIR = "/content/drive/MyDrive/video_embeddings_project_edu"  # Drive project directory
DRIVE_INDEX_PATH = os.path.join(DRIVE_PROJECT_DIR, "video_embeddings.index")  # Drive index path
DRIVE_METADATA_PATH = os.path.join(DRIVE_PROJECT_DIR, "metadata.pkl")  # Drive metadata path
CHECKPOINT_PATH = "/content/checkpoint.pkl"  # Local checkpoint path
DRIVE_CHECKPOINT_PATH = os.path.join(DRIVE_PROJECT_DIR, "checkpoint.pkl")  # Drive checkpoint path
METADATA_PATH = "/content/metadata.pkl"  # Local metadata path

# Download parameters
DOWNLOAD_RETRIES = 3  # Maximum number of retry attempts
DOWNLOAD_DELAY = 2  # Base delay between retries in seconds

# Make sure directories exist
os.makedirs(TEMP_DIR, exist_ok=True)
os.makedirs(DRIVE_PROJECT_DIR, exist_ok=True)

# Initialize model and processor
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch16")
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch16")

# Broadcast model state dictionary to worker nodes
try:
    bc_model_state = spark.sparkContext.broadcast(model.state_dict())
    bc_processor = spark.sparkContext.broadcast(processor)
except:
    # For local testing without Spark
    bc_model_state = type('obj', (object,), {'value': model.state_dict()})
    bc_processor = type('obj', (object,), {'value': processor})

# ============================================
# UTILITY FUNCTIONS
# ============================================

def print_memory_usage():
    """Print current memory usage"""
    # RAM
    ram = psutil.virtual_memory()
    ram_used = ram.used / (1024 ** 3)  # GB
    ram_total = ram.total / (1024 ** 3)  # GB
    print(f"RAM: {ram_used:.1f}/{ram_total:.1f} GB ({ram.percent:.1f}%)")

    # GPU if available
    try:
        gpus = GPUtil.getGPUs()
        if gpus:
            gpu = gpus[0]
            print(f"GPU: {gpu.memoryUsed}/{gpu.memoryTotal} MB ({gpu.memoryUtil*100:.1f}%)")
    except:
        pass

def load_checkpoint():
    """Load checkpoint from disk"""
    if os.path.exists(CHECKPOINT_PATH):
        try:
            with open(CHECKPOINT_PATH, 'rb') as f:
                return pickle.load(f)
        except Exception as e:
            print(f"Error loading checkpoint: {e}")

    # If local checkpoint doesn't exist or couldn't be loaded, try Drive
    if os.path.exists(DRIVE_CHECKPOINT_PATH):
        try:
            with open(DRIVE_CHECKPOINT_PATH, 'rb') as f:
                checkpoint = pickle.load(f)
            # Save locally for faster access next time
            with open(CHECKPOINT_PATH, 'wb') as f:
                pickle.dump(checkpoint, f)
            return checkpoint
        except Exception as e:
            print(f"Error loading checkpoint from Drive: {e}")

    # If no checkpoint exists, return empty one
    return {"processed_ids": set(), "last_batch": 0, "index_checkpoints": []}

def save_checkpoint(checkpoint):
    """Save checkpoint to disk"""
    try:
        with open(CHECKPOINT_PATH, 'wb') as f:
            pickle.dump(checkpoint, f)
        return True
    except Exception as e:
        print(f"Error saving checkpoint: {e}")
        return False

def load_metadata():
    """Load metadata from disk"""
    if os.path.exists(METADATA_PATH):
        try:
            with open(METADATA_PATH, 'rb') as f:
                return pickle.load(f)
        except Exception as e:
            print(f"Error loading metadata: {e}")

    # If local metadata doesn't exist or couldn't be loaded, try Drive
    if os.path.exists(DRIVE_METADATA_PATH):
        try:
            with open(DRIVE_METADATA_PATH, 'rb') as f:
                metadata = pickle.load(f)
            # Save locally for faster access next time
            with open(METADATA_PATH, 'wb') as f:
                pickle.dump(metadata, f)
            return metadata
        except Exception as e:
            print(f"Error loading metadata from Drive: {e}")

    # If no metadata exists, return empty list
    return []

def save_metadata(metadata):
    """Save metadata to disk"""
    try:
        with open(METADATA_PATH, 'wb') as f:
            pickle.dump(metadata, f)
        return True
    except Exception as e:
        print(f"Error saving metadata: {e}")
        return False

def sync_to_drive():
    """Sync important files to Google Drive"""
    files_to_sync = [
        (CHECKPOINT_PATH, DRIVE_CHECKPOINT_PATH),
        (METADATA_PATH, DRIVE_METADATA_PATH),
        (LOCAL_INDEX_PATH, DRIVE_INDEX_PATH)
    ]

    for src, dst in files_to_sync:
        if os.path.exists(src):
            try:
                shutil.copy(src, dst)
                print(f"Synced {src} to {dst}")
            except Exception as e:
                print(f"Error syncing {src} to {dst}: {e}")

def check_csv(csv_path):
    """Check if CSV file is valid"""
    try:
        df = pd.read_csv(csv_path)
        if 'video_id' in df.columns or 'video_path' in df.columns:
            return True
        else:
            print("CSV must have either 'video_id' or 'video_path' column")
            return False
    except Exception as e:
        print(f"Error checking CSV: {e}")
        return False

def upload_csv_to_drive(local_csv_path):
    """Upload CSV file to Google Drive"""
    drive_csv_path = os.path.join(DRIVE_PROJECT_DIR, os.path.basename(local_csv_path))
    try:
        shutil.copy(local_csv_path, drive_csv_path)
        print(f"Uploaded {local_csv_path} to {drive_csv_path}")
        return drive_csv_path
    except Exception as e:
        print(f"Error uploading CSV to Drive: {e}")
        return None

# ============================================
# VIDEO PROCESSING FUNCTIONS
# ============================================

def download_video(video_id, retry=0):
    """Download a YouTube video with optimized settings and retries"""
    if retry >= DOWNLOAD_RETRIES:
        print(f"Max retries reached for {video_id}")
        return None

    # Add small delay between retries (and small random component)
    if retry > 0:
        time.sleep(DOWNLOAD_DELAY + random.uniform(0, 1))

    temp_dir = tempfile.mkdtemp()
    output_path = f"{temp_dir}/{video_id}.%(ext)s"
    url = f"https://www.youtube.com/watch?v={video_id}"

    cmd = [
        "yt-dlp",
        "-f", "worst",  # Use worst quality to ensure successful download
        "--cookies","/content/youtube_cookies.txt",
        "--downloader", "native",
        "-o", output_path,
        url
    ]

    try:
        print(f"Attempting download of {video_id} (attempt {retry+1}/{DOWNLOAD_RETRIES})")
        result = subprocess.run(cmd, capture_output=True, text=True)

        # Check if there was an error
        if result.returncode != 0:
            print(f"Error downloading {video_id}: {result.stderr[:200]}...")
            # Retry with exponential backoff
            return download_video(video_id, retry + 1)

        # Check if file was actually downloaded
        for file in os.listdir(temp_dir):
            if file.startswith(video_id):
                file_path = os.path.join(temp_dir, file)
                file_size = os.path.getsize(file_path)
                if file_size < 10000:  # Less than 10KB is probably an error
                    print(f"Downloaded file for {video_id} is too small: {file_size} bytes")
                    os.remove(file_path)
                    # Retry with a different format
                    return download_video(video_id, retry + 1)
                print(f"Successfully downloaded {video_id}: {file_size/1024:.1f}KB")
                return file_path

        print(f"No file found for {video_id} after download command completed")
        return download_video(video_id, retry + 1)

    except Exception as e:
        print(f"Exception while downloading {video_id}: {str(e)}")
        return download_video(video_id, retry + 1)

def test_single_download():
    """Test downloading a single video to verify YouTube connectivity"""
    # Use a reliable, short video for testing
    test_video_id = "jNQXAC9IVRw"  # "Me at the zoo" (first YouTube video)
    print(f"Testing YouTube connectivity with video ID: {test_video_id}")

    path = download_video(test_video_id)
    if path and os.path.exists(path):
        print("✓ YouTube download test successful!")
        try:
            os.remove(path)
            os.rmdir(os.path.dirname(path))
        except:
            pass
        return True
    else:
        print("✗ YouTube download test failed.")
        return False

def extract_frames(video_path, max_frames=MAX_FRAMES_PER_VIDEO):
    """Extract frames from a video"""
    frames = []
    try:
        # Open video
        video = cv2.VideoCapture(video_path)

        # Get video properties
        fps = video.get(cv2.CAP_PROP_FPS)
        total_frames = int(video.get(cv2.CAP_PROP_FRAME_COUNT))
        duration = total_frames / fps if fps > 0 else 0

        # Only extract frames from the first 30 seconds
        max_duration = 30.0  # seconds
        frames_to_extract = min(int(fps * min(max_duration, duration)), max_frames)

        # Calculate frame interval to distribute frames evenly
        if frames_to_extract < max_frames:
            interval = 1
        else:
            interval = total_frames // max_frames

        # Extract frames
        frame_count = 0
        frame_idx = 0

        while frame_count < max_frames and frame_idx < total_frames:
            # Set position
            video.set(cv2.CAP_PROP_POS_FRAMES, frame_idx)

            # Read frame
            success, frame = video.read()
            if not success:
                break

            # Convert from BGR to RGB
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)

            # Convert to PIL Image
            frame = Image.fromarray(frame)

            # Add to frames list
            frames.append(frame)

            # Update counters
            frame_count += 1
            frame_idx += interval

        # Release video
        video.release()

    except Exception as e:
        print(f"Error extracting frames: {str(e)}")

    return frames

def generate_embeddings(frames, model, processor, device):
    """Generate embeddings for frames"""
    embeddings = []

    try:
        # Process in batches to avoid OOM errors
        batch_size = 8
        for i in range(0, len(frames), batch_size):
            batch_frames = frames[i:i+batch_size]

            # Process frames
            with torch.no_grad():
                inputs = processor(images=batch_frames, return_tensors="pt", padding=True).to(device)
                outputs = model.get_image_features(**inputs)

                # Normalize embeddings
                batch_embeddings = outputs / outputs.norm(dim=-1, keepdim=True)

                # Convert to numpy and add to results
                batch_embeddings = batch_embeddings.cpu().numpy()
                embeddings.extend(batch_embeddings)

    except Exception as e:
        print(f"Error generating embeddings: {str(e)}")

    return embeddings

# ============================================
# VIDEO DOWNLOAD AND PROCESSING FUNCTIONS
# ============================================

# First, let's define a function for sequential downloading
def download_videos_sequentially(video_ids):
    """Download videos one by one with no parallelism"""
    print(f"Downloading {len(video_ids)} videos sequentially (no parallelism)...")

    results = {}
    for idx, video_id in enumerate(video_ids):
        print(f"Downloading video {idx+1}/{len(video_ids)}: {video_id}")
        try:
            path = download_video(video_id)
            results[video_id] = path
            if path:
                print(f"✓ Successfully downloaded {video_id}")
            else:
                print(f"✗ Failed to download {video_id}")

            # Add delay between downloads to avoid rate limiting
            time.sleep(1)
        except Exception as e:
            print(f"Error downloading {video_id}: {str(e)}")
            results[video_id] = None

    print(f"Downloaded {sum(1 for p in results.values() if p)} videos successfully")
    return results

# Fix for the generate_embeddings_in_parallel function to properly handle results
def generate_embeddings_in_parallel(video_paths):
    """Generate embeddings in parallel using Spark for already downloaded videos"""
    if not video_paths:
        print("No videos to process")
        return []

    # Convert dict to list of tuples for Spark
    video_items = [(video_id, path) for video_id, path in video_paths.items() if path]

    if not video_items:
        print("No valid video paths")
        return []

    print(f"Generating embeddings for {len(video_items)} videos using Spark...")

    # Define a helper function that processes partition with shared model
    def process_partition(partition):
        # Load model once per partition using broadcasted state dict
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        items = list(partition)
        print(f"Processing partition with {len(items)} videos on {device}")

        # Initialize fresh model from broadcasted state dict
        model = CLIPModel.from_pretrained("openai/clip-vit-base-patch16")
        model.load_state_dict(bc_model_state.value)
        processor = bc_processor.value

        # Convert to fp16 if using GPU
        if device.type == 'cuda':
            model.half()

        model = model.to(device)
        model.eval()

        results = []
        for video_id, video_path in items:
            try:
                print(f"\nProcessing video: {video_id}")

                # Skip download - videos are already downloaded
                if not os.path.exists(video_path):
                    print(f"✗ {video_id}: Video file not found at {video_path}")
                    continue

                # Extract frames
                frames = extract_frames(video_path)

                if not frames or len(frames) == 0:
                    print(f"✗ {video_id}: No frames extracted")
                    continue

                # Generate embeddings
                embeddings = generate_embeddings(frames, model, processor, device)

                if not embeddings or len(embeddings) == 0:
                    print(f"✗ {video_id}: No embeddings generated")
                    continue

                # Success! Add to results
                frame_indices = list(range(len(embeddings)))
                print(f"✓ {video_id}: Successfully processed {len(embeddings)} frames")
                results.append((video_id, embeddings, frame_indices))

            except Exception as e:
                print(f"✗ {video_id}: Error - {str(e)}")
                traceback.print_exc()

        # Explicitly clear GPU memory after each partition
        del model
        torch.cuda.empty_cache() if torch.cuda.is_available() else None

        return results

    # Create RDD with video items
    num_partitions = min(len(video_items), MAX_PARTITIONS)
    print(f"Using {num_partitions} partitions for {len(video_items)} videos")

    # Create RDD from video items (not just IDs)
    rdd = spark.sparkContext.parallelize(video_items, numSlices=num_partitions)

    # Process using mapPartitions for better efficiency
    start_time = time.time()
    print(f"Starting parallel processing with {num_partitions} partitions")
    results_rdd = rdd.mapPartitions(process_partition)
    all_results = results_rdd.collect()

    # Adaptive flattening - handle both possible Spark return structures
    results = []

    if len(all_results) > 0:
        # Check if we need to flatten (if all_results is a list of lists)
        if isinstance(all_results[0], list):
            # Flatten nested structure
            for partition_results in all_results:
                for result in partition_results:
                    results.append(result)
        else:
            # all_results is already flat
            results = all_results

    elapsed = time.time() - start_time
    print(f"\nProcessed {len(video_items)} videos, got {len(results)} successful results")
    if len(results) > 0:
        print(f"Processing speed: {len(results)/elapsed:.2f} successful videos/second")

    return results

# ============================================
# CHECKPOINT SYSTEM
# ============================================

# Add minimal checkpoint functions
def save_index_checkpoint(index, metadata, batch_idx):
    """Save index checkpoint"""
    if not index or not metadata: return False

    # Create dirs
    os.makedirs(os.path.join(DRIVE_PROJECT_DIR, "checkpoints"), exist_ok=True)
    os.makedirs("/content/checkpoints", exist_ok=True)

    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    checkpoint_path = f"/content/checkpoints/index_batch{batch_idx}_{timestamp}.faiss"
    metadata_path = f"/content/checkpoints/metadata_batch{batch_idx}_{timestamp}.json"

    try:
        # Save locally
        faiss.write_index(index, checkpoint_path)
        with open(metadata_path, 'w') as f:
            json.dump(metadata, f)

        # Save to Drive
        drive_path = os.path.join(DRIVE_PROJECT_DIR, "checkpoints", f"index_batch{batch_idx}_{timestamp}.faiss")
        drive_metadata = os.path.join(DRIVE_PROJECT_DIR, "checkpoints", f"metadata_batch{batch_idx}_{timestamp}.json")
        shutil.copy(checkpoint_path, drive_path)
        shutil.copy(metadata_path, drive_metadata)

        # Log checkpoint
        checkpoint = load_checkpoint()
        if "index_checkpoints" not in checkpoint:
            checkpoint["index_checkpoints"] = []

        checkpoint["index_checkpoints"].append({
            "batch_idx": batch_idx,
            "timestamp": timestamp,
            "local_path": checkpoint_path,
            "drive_path": drive_path
        })
        save_checkpoint(checkpoint)

        print(f"✓ Checkpoint saved (batch {batch_idx}, {index.ntotal} embeddings)")
        return True
    except Exception as e:
        print(f"✗ Error saving checkpoint: {e}")
        return False

def find_latest_checkpoint():
    """Find latest checkpoint"""
    checkpoint = load_checkpoint()

    # Try loading from checkpoint registry
    if "index_checkpoints" in checkpoint and checkpoint["index_checkpoints"]:
        checkpoints = sorted(checkpoint["index_checkpoints"], key=lambda x: x.get("timestamp", ""), reverse=True)

        for c in checkpoints:
            if os.path.exists(c.get("local_path", "")):
                metadata_path = c["local_path"].replace("index_", "metadata_")
                return c["local_path"], metadata_path, c["batch_idx"]
            elif os.path.exists(c.get("drive_path", "")):
                local_path = f"/content/checkpoints/restored_index.faiss"
                metadata_path = f"/content/checkpoints/restored_metadata.json"

                # Copy to local
                shutil.copy(c["drive_path"], local_path)
                drive_metadata = c["drive_path"].replace("index_", "metadata_")
                shutil.copy(drive_metadata, metadata_path)

                return local_path, metadata_path, c["batch_idx"]

    # Search for files directly
    local_files = glob.glob("/content/checkpoints/index_batch*.faiss")
    if local_files:
        latest = max(local_files, key=os.path.getctime)
        metadata = latest.replace("index_", "metadata_")
        if os.path.exists(metadata):
            match = re.search(r'batch(\d+)', latest)
            batch_idx = int(match.group(1)) if match else 0
            return latest, metadata, batch_idx

    return None, None, 0

def load_index_checkpoint():
    """Load index from checkpoint"""
    index_path, metadata_path, batch_idx = find_latest_checkpoint()

    if index_path and os.path.exists(index_path):
        try:
            # Load checkpoint
            index = faiss.read_index(index_path)
            with open(metadata_path, 'r') as f:
                metadata = json.load(f)

            # Save as main index
            faiss.write_index(index, LOCAL_INDEX_PATH)
            save_metadata(metadata)

            print(f"✓ Loaded checkpoint (batch {batch_idx}, {index.ntotal} embeddings)")
            return index, metadata, batch_idx
        except Exception as e:
            print(f"✗ Error loading checkpoint: {e}")

    # Fall back to regular loading
    index = None
    if os.path.exists(LOCAL_INDEX_PATH):
        index = faiss.read_index(LOCAL_INDEX_PATH)
    elif os.path.exists(DRIVE_INDEX_PATH):
        index = faiss.read_index(DRIVE_INDEX_PATH)
        faiss.write_index(index, LOCAL_INDEX_PATH)

    metadata = load_metadata()
    return index, metadata, 0

# ============================================
# MAIN PROCESSING FUNCTIONS
# ============================================

# Now fix the process_all_videos function to handle embeddings properly
def process_all_videos(csv_path, total_videos=10000, force_reprocess=False):
    """Process all videos with sequential downloading and parallel embedding generation"""
    overall_start_time = time.time()
    print(f"Starting processing of up to {total_videos} videos")

    # Check if file exists
    if not os.path.exists(csv_path):
        print(f"ERROR: File not found: {csv_path}")
        return None, None

    # Load CSV file
    df = pd.read_csv(csv_path)
    # df = df[6000:]
    # Extract video IDs
    if 'video_path' in df.columns:
        df['video_id'] = df['video_path'].apply(lambda x: str(x).split('.')[0])
    elif 'video_id' not in df.columns:
        print("ERROR: CSV must have either 'video_path' or 'video_id' column")
        return None, None

    # Get all video IDs
    all_video_ids = df['video_id'].tolist()[:total_videos]
    print(f"Found {len(all_video_ids)} videos to process")

    # Load checkpoint or start fresh if force_reprocess
    if not force_reprocess:
        checkpoint = load_checkpoint()
        processed_ids = checkpoint["processed_ids"]
        last_batch = checkpoint["last_batch"]
    else:
        print("Force reprocessing enabled - ignoring checkpoint")
        processed_ids = set()
        last_batch = 0

    # Load index from checkpoint or create new
    index, metadata, _ = load_index_checkpoint()

    # Process in batches
    num_batches = (len(all_video_ids) + BATCH_SIZE - 1) // BATCH_SIZE

    for batch_idx in range(last_batch, num_batches):
        start_idx = batch_idx * BATCH_SIZE
        end_idx = min(start_idx + BATCH_SIZE, len(all_video_ids))

        # Get videos for this batch (skip already processed ones)
        batch_videos = [vid for vid in all_video_ids[start_idx:end_idx]
                        if vid not in processed_ids]

        if not batch_videos:
            print(f"Batch {batch_idx+1}/{num_batches} already processed, skipping")
            continue

        print(f"\n{'='*60}")
        print(f"PROCESSING BATCH {batch_idx+1}/{num_batches}")
        print(f"{'='*60}")
        print(f"Videos in this batch: {len(batch_videos)}")
        print(f"First few videos: {batch_videos[:5]}")
        print_memory_usage()

        # Process this batch with timing
        batch_start = time.time()
        print(f"Batch {batch_idx+1}: Starting processing at {datetime.now().strftime('%H:%M:%S')}")

        # STEP 1: Download videos sequentially
        video_paths = download_videos_sequentially(batch_videos)

        # STEP 2: Generate embeddings in parallel with Spark
        batch_results = generate_embeddings_in_parallel(video_paths)

        # Cleanup downloaded videos
        for path in video_paths.values():
            if path and os.path.exists(path):
                try:
                    os.remove(path)
                    parent_dir = os.path.dirname(path)
                    if parent_dir != TEMP_DIR and os.path.exists(parent_dir):
                        os.rmdir(parent_dir)
                except:
                    pass

        batch_time = time.time() - batch_start
        print(f"Batch {batch_idx+1}: Processed {len(batch_results)}/{len(batch_videos)} videos in {batch_time:.2f}s")

        # Mark ALL videos as processed
        for video_id in batch_videos:
            processed_ids.add(video_id)

        # Create batch metadata and collect embeddings
        batch_metadata = []
        all_embeddings = []

        # Timing embedding aggregation
        t_agg = time.time()

        # Process results with safety check
        for result in batch_results:
            try:
                # Validate result structure
                if not isinstance(result, tuple) or len(result) != 3:
                    print(f"WARNING: Skipping result with invalid structure (expected 3 elements)")
                    continue

                video_id, embeddings, frame_indices = result

                # Add embeddings
                for i, embedding in enumerate(embeddings):
                    all_embeddings.append(embedding)
                    batch_metadata.append({
                        'video_id': video_id,
                        'frame_idx': frame_indices[i],
                        'embedding_idx': len(metadata) + len(batch_metadata) - 1
                    })
            except Exception as e:
                print(f"Error processing result: {e}")
                continue

        print(f"Aggregated {len(all_embeddings)} embeddings in {time.time()-t_agg:.2f}s")

        # Add to FAISS index
        if all_embeddings:
            t_index = time.time()
            print(f"Adding {len(all_embeddings)} new embeddings to index")
            embeddings_array = np.array(all_embeddings).astype('float32')

            if index is None:
                # Create index
                dim = embeddings_array.shape[1]
                print(f"Creating new FAISS index with dimension {dim}")
                index = faiss.IndexFlatIP(dim)

            # Add to index
            index.add(embeddings_array)

            # Save index locally
            print(f"Saving local index with {index.ntotal} total embeddings")
            faiss.write_index(index, LOCAL_INDEX_PATH)

            # Update metadata
            metadata.extend(batch_metadata)
            save_metadata(metadata)

            print(f"Index operations took {time.time()-t_index:.2f}s")

            # Save checkpoint
            save_index_checkpoint(index, metadata, batch_idx)

            # Sync to Drive
            t_sync = time.time()
            print("Syncing index and metadata to Drive...")
            sync_to_drive()
            print(f"Sync to Drive completed in {time.time()-t_sync:.2f}s")

        # Save processing checkpoint
        checkpoint = {
            "processed_ids": processed_ids,
            "last_batch": batch_idx + 1
        }
        save_checkpoint(checkpoint)
        print(f"Checkpoint saved at batch {batch_idx+1}")

        # Clean up memory
        del all_embeddings
        del batch_results
        if 'embeddings_array' in locals():
            del embeddings_array

        # Force garbage collection
        gc.collect()
        torch.cuda.empty_cache() if torch.cuda.is_available() else None

        # Overall progress
        overall_progress = len(processed_ids) / len(all_video_ids) * 100
        elapsed = time.time() - overall_start_time
        eta = (elapsed / (batch_idx - last_batch + 1)) * (num_batches - batch_idx - 1) if batch_idx > last_batch else 0

        print(f"\nOverall progress: {len(processed_ids)}/{len(all_video_ids)} videos ({overall_progress:.1f}%)")
        print(f"Elapsed time: {elapsed/3600:.2f} hours")
        print(f"Estimated remaining time: {eta/3600:.2f} hours")
        print(f"{'='*60}\n")

    # Final sync to Drive
    sync_to_drive()

    # Load final index to return
    if os.path.exists(LOCAL_INDEX_PATH):
        index = faiss.read_index(LOCAL_INDEX_PATH)
        print(f"Final index size: {index.ntotal} embeddings")
    else:
        index = None

    total_time = time.time() - overall_start_time
    print(f"Total processing time: {total_time/3600:.2f} hours")
    return index, metadata

# Update the run_pipeline function to use our new approach
def run_pipeline(csv_path=None, batch_size=None, video_limit=None, force_reprocess=False):
    """Run the pipeline with the specified parameters"""
    print("\n" + "="*70)
    print("VIDEO EMBEDDING PIPELINE WITH HYBRID APPROACH")
    print("="*70)

    # Check GPU
    try:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Using device: {device}")
        if torch.cuda.is_available():
            print(f"GPU: {torch.cuda.get_device_name(0)}")
            print(f"Optimizations: FP16=True, Fast CLIP model=True")
            print(f"Frame extraction: First 90 seconds only (max {MAX_FRAMES_PER_VIDEO} frames)")
    except:
        print("Could not detect GPU")

    # Initial memory usage
    print("\nInitial memory usage:")
    print_memory_usage()

    # Override batch size if specified
    if batch_size is not None:
        global BATCH_SIZE
        BATCH_SIZE = batch_size
        print(f"Using custom batch size: {BATCH_SIZE}")

    # Test video download to verify YouTube connectivity
    test_download_result = test_single_download()
    if not test_download_result:
        print("\nWARNING: Initial download test failed. Processing may encounter issues.")
        user_input = input("Do you want to continue anyway? (y/n): ")
        if user_input.lower() != 'y':
            print("Aborting pipeline.")
            return

    # Determine CSV path
    if csv_path is None:
        # Check if we have a CSV already in Google Drive
        drive_csv_path = os.path.join(DRIVE_PROJECT_DIR, "all_videos.csv")
        local_csv_path = '/content/all_videos.csv'

        csv_path_to_use = None

        # First check Drive
        if os.path.exists(drive_csv_path):
            print(f"Found CSV in Google Drive: {drive_csv_path}")
            csv_path_to_use = drive_csv_path
        # Then check local
        elif os.path.exists(local_csv_path):
            print(f"Found CSV locally, uploading to Drive")
            drive_csv_path = upload_csv_to_drive(local_csv_path)
            if drive_csv_path:
                csv_path_to_use = drive_csv_path
            else:
                csv_path_to_use = local_csv_path
        else:
            print("No CSV file found. Please upload a CSV file with video IDs.")
            return
    else:
        csv_path_to_use = csv_path

    # Run processing
    if csv_path_to_use and check_csv(csv_path_to_use):
        total_videos = video_limit if video_limit else 10000
        print(f"\nStarting processing of up to {total_videos} videos")
        print(f"Using batch size: {BATCH_SIZE}")
        print(f"Using hybrid approach: Sequential downloads + Parallel embeddings")
        print(f"Force reprocessing: {force_reprocess}")
        print(f"All results will be saved to both local storage and Google Drive")

        index, metadata = process_all_videos(csv_path_to_use, total_videos=total_videos,
                                           force_reprocess=force_reprocess)

        # Final cleanup
        del index
        del metadata
        gc.collect()
        torch.cuda.empty_cache() if torch.cuda.is_available() else None

        print("\nProcessing complete! FAISS index created.")
        print(f"Index path: {DRIVE_INDEX_PATH}")
        print(f"Metadata path: {DRIVE_METADATA_PATH}")
        print("\nTo search: search_by_text('your query here')")

        # Final memory usage
        print("\nFinal memory usage:")
        print_memory_usage()
    else:
        print("\nPlease upload a valid CSV file with video IDs before running.")

# ============================================
# SEARCH FUNCTIONALITY
# ============================================

def search_by_text(query, top_k=5):
    """Search for videos by text query"""
    # Load index and metadata
    if not os.path.exists(LOCAL_INDEX_PATH):
        if os.path.exists(DRIVE_INDEX_PATH):
            shutil.copy(DRIVE_INDEX_PATH, LOCAL_INDEX_PATH)
        else:
            print("ERROR: Index not found")
            return []

    # Load metadata
    metadata = load_metadata()
    if not metadata:
        print("ERROR: Metadata not found")
        return []

    # Load index
    index = faiss.read_index(LOCAL_INDEX_PATH)

    # Load model and processor
    model = CLIPModel.from_pretrained("openai/clip-vit-base-patch16")
    processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch16")

    # Generate text embedding
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    model.eval()

    with torch.no_grad():
        inputs = processor(text=[query], return_tensors="pt", padding=True).to(device)
        text_embedding = model.get_text_features(**inputs)
        text_embedding = text_embedding / text_embedding.norm(dim=-1, keepdim=True)
        text_embedding = text_embedding.cpu().numpy().astype('float32')

    # Search index
    D, I = index.search(text_embedding, top_k)

    # Prepare results
    results = []
    for i, (score, idx) in enumerate(zip(D[0], I[0])):
        if idx < len(metadata):
            entry = metadata[idx]
            video_id = entry['video_id']
            results.append({
                'rank': i+1,
                'video_id': video_id,
                'frame_idx': entry['frame_idx'],
                'score': float(score),
                'link': f"https://www.youtube.com/watch?v={video_id}"
            })

    return results

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


config.json:   0%|          | 0.00/4.10k [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/599M [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/599M [00:00<?, ?B/s]

Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.


preprocessor_config.json:   0%|          | 0.00/316 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/905 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/961k [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/525k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/2.22M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/389 [00:00<?, ?B/s]

In [4]:

from pyspark.sql import SparkSession
import findspark
import warnings

# Initialize Spark
def initialize_spark():
    """Initialize Spark session"""
    print("Initializing Spark...")

    try:
        # Try to initialize findspark to locate Spark installation
        try:
            findspark.init()
            print(" Successfully initialized findspark")
        except:
            print(" Failed to initialize findspark, will try direct SparkSession creation")
            pass

        # Create Spark session
        spark = SparkSession.builder \
            .appName("VideoEmbeddingPipeline") \
            .config("spark.driver.memory", "4g") \
            .config("spark.executor.memory", "4g") \
            .config("spark.dynamicAllocation.enabled", "false") \
            .config("spark.driver.maxResultSize", "2g") \
            .getOrCreate()

        # Suppress Spark warnings
        warnings.filterwarnings("ignore")
        spark.sparkContext.setLogLevel("ERROR")

        # Set global variable
        globals()['spark'] = spark

        # Print successful initialization
        print(f" Successfully initialized Spark {spark.version}")
        print(f" Using {spark.sparkContext.defaultParallelism} cores")

        return spark

    except Exception as e:
        print(f" Error initializing Spark: {e}")
        print("Falling back to single-node processing...")
        return spark

def update_checkpoint_registry():
    """Update checkpoint registry to include files in the checkpoints directory"""
    print("\n===== UPDATING CHECKPOINT REGISTRY =====")

    # Load existing checkpoint
    print(f"Loading existing checkpoint from: {CHECKPOINT_PATH}")
    try:
        with open(CHECKPOINT_PATH, 'rb') as f:
            checkpoint = pickle.load(f)

        print(f"Current checkpoint info:")
        print(f"  Last batch: {checkpoint.get('last_batch', 'Not found')}")
        print(f"  Processed videos: {len(checkpoint.get('processed_ids', set()))}")
        print(f"  Index checkpoints: {len(checkpoint.get('index_checkpoints', []))}")
    except Exception as e:
        print(f"Error loading checkpoint: {e}")
        print("Creating new checkpoint...")
        checkpoint = {
            "processed_ids": set(),
            "last_batch": 0,
            "index_checkpoints": []
        }

    # Find all FAISS files in checkpoints directory
    drive_checkpoints_dir = os.path.join(DRIVE_PROJECT_DIR, "checkpoints")
    print(f"\nScanning for checkpoint files in: {drive_checkpoints_dir}")

    if not os.path.exists(drive_checkpoints_dir):
        print(f"  ✗ Checkpoint directory not found")
        return checkpoint

    faiss_files = glob.glob(os.path.join(drive_checkpoints_dir, "index_batch*.faiss"))
    print(f"  Found {len(faiss_files)} FAISS checkpoint files")

    # Track highest batch number
    highest_batch = checkpoint.get("last_batch", 0)

    # Process each FAISS file
    for faiss_file in faiss_files:
        filename = os.path.basename(faiss_file)
        match = re.search(r'batch(\d+)_(\d{8}_\d{6})', filename)
        if match:
            batch_idx = int(match.group(1))
            timestamp = match.group(2)

            # Update highest batch if needed
            highest_batch = max(highest_batch, batch_idx)

            # Check if this file is already registered
            file_exists = False
            for entry in checkpoint.get("index_checkpoints", []):
                if os.path.basename(entry.get("drive_path", "")) == filename:
                    file_exists = True
                    break

            if not file_exists:
                # Find corresponding metadata file
                metadata_file = faiss_file.replace("index_", "metadata_")

                if os.path.exists(metadata_file):
                    # Create local paths
                    local_faiss = os.path.join("/content/checkpoints", filename)
                    local_metadata = os.path.join("/content/checkpoints", os.path.basename(metadata_file))

                    # Add to checkpoint registry
                    entry = {
                        "batch_idx": batch_idx,
                        "timestamp": timestamp,
                        "local_path": local_faiss,
                        "drive_path": faiss_file
                    }

                    checkpoint.setdefault("index_checkpoints", []).append(entry)
                    print(f"  Added checkpoint for batch {batch_idx} to registry")

                    # Try to load metadata to get processed IDs
                    try:
                        with open(metadata_file, 'r') as f:
                            metadata = json.load(f)

                        # Extract video IDs from metadata
                        for item in metadata:
                            if 'video_id' in item:
                                checkpoint.setdefault("processed_ids", set()).add(item['video_id'])
                    except Exception as e:
                        print(f"  Error loading metadata from {metadata_file}: {e}")

    # Update last_batch if needed
    if highest_batch > checkpoint.get("last_batch", 0):
        print(f"\nUpdating last_batch from {checkpoint.get('last_batch', 0)} to {highest_batch}")
        checkpoint["last_batch"] = highest_batch

    # Sort checkpoints by timestamp
    checkpoint["index_checkpoints"] = sorted(
        checkpoint.get("index_checkpoints", []),
        key=lambda x: x.get("timestamp", ""),
        reverse=True
    )

    # Save updated checkpoint
    print("\nSaving updated checkpoint...")
    with open(CHECKPOINT_PATH, 'wb') as f:
        pickle.dump(checkpoint, f)

    # Save to Drive
    with open(DRIVE_CHECKPOINT_PATH, 'wb') as f:
        pickle.dump(checkpoint, f)

    print(f"Checkpoint updated and saved!")
    print(f"Last batch: {checkpoint['last_batch']}")
    print(f"Processed videos: {len(checkpoint.get('processed_ids', set()))}")
    print(f"Index checkpoints: {len(checkpoint['index_checkpoints'])}")

    return checkpoint

# Function to run pipeline with updated checkpoint
def run_pipeline_with_updated_checkpoint(csv_path=None, batch_size=None, video_limit=None, force_reprocess=False):
    """Run the pipeline with an updated checkpoint registry"""
    print("\n" + "="*70)
    print("VIDEO EMBEDDING PIPELINE WITH CHECKPOINT UPDATE")
    print("="*70)

    initialize_spark()

    # Create checkpoints directory if it doesn't exist
    os.makedirs("/content/checkpoints", exist_ok=True)
    os.makedirs(os.path.join(DRIVE_PROJECT_DIR, "checkpoints"), exist_ok=True)

    # Update checkpoint registry
    print("Checking and updating checkpoint registry...")
    checkpoint = update_checkpoint_registry()

    print(f"\nCheckpoint registry updated. Pipeline will resume from batch {checkpoint['last_batch'] + 1}")
    print(f"Would you like to continue? (y/n)")

    choice = input().lower()
    if not choice.startswith('y'):
        print("Aborting pipeline.")
        return

    # Continue with normal pipeline
    return run_pipeline(csv_path, batch_size, video_limit, force_reprocess)

In [5]:
run_pipeline_with_updated_checkpoint('/content/drive/MyDrive/video_embeddings_project_edu/all_videos.csv', batch_size=300, video_limit=6000, force_reprocess=False)


VIDEO EMBEDDING PIPELINE WITH CHECKPOINT UPDATE
Initializing Spark...
 Successfully initialized findspark
 Successfully initialized Spark 3.5.1
 Using 2 cores
Checking and updating checkpoint registry...

===== UPDATING CHECKPOINT REGISTRY =====
Loading existing checkpoint from: /content/checkpoint.pkl
Error loading checkpoint: [Errno 2] No such file or directory: '/content/checkpoint.pkl'
Creating new checkpoint...

Scanning for checkpoint files in: /content/drive/MyDrive/video_embeddings_project_edu/checkpoints
  Found 7 FAISS checkpoint files

Updating last_batch from 0 to 6

Saving updated checkpoint...
Checkpoint updated and saved!
Last batch: 6
Processed videos: 0
Index checkpoints: 0

Checkpoint registry updated. Pipeline will resume from batch 7
Would you like to continue? (y/n)
y

VIDEO EMBEDDING PIPELINE WITH HYBRID APPROACH
Using device: cuda
GPU: Tesla T4
Optimizations: FP16=True, Fast CLIP model=True
Frame extraction: First 90 seconds only (max 90 frames)

Initial memory 

KeyboardInterrupt: 

In [6]:
def search_by_text(query, top_k=5):
    """Search for videos by text query"""
    # Load index and metadata
    if not os.path.exists(LOCAL_INDEX_PATH):
        if os.path.exists(DRIVE_INDEX_PATH):
            shutil.copy(DRIVE_INDEX_PATH, LOCAL_INDEX_PATH)
        else:
            print("ERROR: Index not found")
            return []

    # Load metadata
    metadata = load_metadata()
    if not metadata:
        print("ERROR: Metadata not found")
        return []

    # Load index
    index = faiss.read_index(LOCAL_INDEX_PATH)

    # Load model and processor
    model = CLIPModel.from_pretrained("openai/clip-vit-base-patch16")
    processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch16")

    # Generate text embedding
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    model.eval()

    with torch.no_grad():
        inputs = processor(text=[query], return_tensors="pt", padding=True).to(device)
        text_embedding = model.get_text_features(**inputs)
        text_embedding = text_embedding / text_embedding.norm(dim=-1, keepdim=True)
        text_embedding = text_embedding.cpu().numpy().astype('float32')

    # Search index
    D, I = index.search(text_embedding, top_k)

    # Prepare results
    results = []
    for i, (score, idx) in enumerate(zip(D[0], I[0])):
        if idx < len(metadata):
            entry = metadata[idx]
            video_id = entry['video_id']
            results.append({
                'rank': i+1,
                'video_id': video_id,
                'frame_idx': entry['frame_idx'],
                'score': float(score),
                'link': f"https://www.youtube.com/watch?v={video_id}"
            })

    return results

In [7]:
search_by_text("how to measure blood pressure", top_k=5)

[{'rank': 1,
  'video_id': 'Li4oGhfKmDQ.mp4',
  'frame_idx': 66,
  'score': 0.3678836226463318,
  'link': 'https://www.youtube.com/watch?v=Li4oGhfKmDQ.mp4'},
 {'rank': 2,
  'video_id': 'Li4oGhfKmDQ.mp4',
  'frame_idx': 65,
  'score': 0.3602183759212494,
  'link': 'https://www.youtube.com/watch?v=Li4oGhfKmDQ.mp4'},
 {'rank': 3,
  'video_id': 'WyMTF_m-DzY.mp4',
  'frame_idx': 53,
  'score': 0.34256410598754883,
  'link': 'https://www.youtube.com/watch?v=WyMTF_m-DzY.mp4'},
 {'rank': 4,
  'video_id': 'vi2ZGIn4AAo.mp4',
  'frame_idx': 63,
  'score': 0.3382461369037628,
  'link': 'https://www.youtube.com/watch?v=vi2ZGIn4AAo.mp4'},
 {'rank': 5,
  'video_id': 'WyMTF_m-DzY.mp4',
  'frame_idx': 46,
  'score': 0.33516231179237366,
  'link': 'https://www.youtube.com/watch?v=WyMTF_m-DzY.mp4'}]

In [11]:
search_by_text("explain periodic table", top_k=5)

[{'rank': 1,
  'video_id': 'rz4Dd1I_fX0.mp4',
  'frame_idx': 79,
  'score': 0.3329148590564728,
  'link': 'https://www.youtube.com/watch?v=rz4Dd1I_fX0.mp4'},
 {'rank': 2,
  'video_id': 'rz4Dd1I_fX0.mp4',
  'frame_idx': 80,
  'score': 0.3288887143135071,
  'link': 'https://www.youtube.com/watch?v=rz4Dd1I_fX0.mp4'},
 {'rank': 3,
  'video_id': 'rz4Dd1I_fX0.mp4',
  'frame_idx': 64,
  'score': 0.32498079538345337,
  'link': 'https://www.youtube.com/watch?v=rz4Dd1I_fX0.mp4'},
 {'rank': 4,
  'video_id': 'rz4Dd1I_fX0.mp4',
  'frame_idx': 19,
  'score': 0.3211629092693329,
  'link': 'https://www.youtube.com/watch?v=rz4Dd1I_fX0.mp4'},
 {'rank': 5,
  'video_id': 'rz4Dd1I_fX0.mp4',
  'frame_idx': 61,
  'score': 0.31860077381134033,
  'link': 'https://www.youtube.com/watch?v=rz4Dd1I_fX0.mp4'}]